@@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
135135 x_Y = self .Y_Conv (x [:, 0 , ])
136136 x_Cb = self .Cb_Conv (x [:, 1 , ])
137137 x_Cr = self .Cr_Conv (x [:, 2 , ])
138- x = torch .cat ([x_Y , x_Cb , x_Cr ], axis = 1 )
138+ x = torch .cat ([x_Y , x_Cb , x_Cr ], dim = 1 )
139139 return x
140140
141141
@@ -160,12 +160,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
160160
161161 mean_list = torch .zeros ([3 , 64 ])
162162 var_list = torch .zeros ([3 , 64 ])
163- mean_list [0 ] = torch .mean (x [:, 0 , ], axis = 0 )
164- mean_list [1 ] = torch .mean (x [:, 1 , ], axis = 0 )
165- mean_list [2 ] = torch .mean (x [:, 2 , ], axis = 0 )
166- var_list [0 ] = torch .var (x [:, 0 , ], axis = 0 )
167- var_list [1 ] = torch .var (x [:, 1 , ], axis = 0 )
168- var_list [2 ] = torch .var (x [:, 2 , ], axis = 0 )
163+ mean_list [0 ] = torch .mean (x [:, 0 , ], dim = 0 )
164+ mean_list [1 ] = torch .mean (x [:, 1 , ], dim = 0 )
165+ mean_list [2 ] = torch .mean (x [:, 2 , ], dim = 0 )
166+ var_list [0 ] = torch .var (x [:, 0 , ], dim = 0 )
167+ var_list [1 ] = torch .var (x [:, 1 , ], dim = 0 )
168+ var_list [2 ] = torch .var (x [:, 2 , ], dim = 0 )
169169 return mean_list , var_list
170170
171171
@@ -470,6 +470,11 @@ def forward(self, x):
470470 x = self .forward_head (x )
471471 return x
472472
473+
474+ # --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ ---
475+ # (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략)
476+ # 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요.
477+
473478class LayerNorm (nn .Module ):
474479 def __init__ (self , normalized_shape , eps = 1e-6 , data_format = "channels_last" ):
475480 super ().__init__ ()
@@ -604,5 +609,4 @@ def csatv2(pretrained: bool = False, **kwargs) -> CSATv2:
604609 img_size = kwargs .pop ('img_size' , 512 ),
605610 num_classes = kwargs .pop ('num_classes' , 1000 ),
606611 )
607-
608612 return _create_csatv2 ('csatv2' , pretrained , ** dict (model_args , ** kwargs ))
0 commit comments