Skip to content

Commit 09d3d25

Browse files
authored
Update csatv2.py
1 parent 4f03468 commit 09d3d25

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

timm/models/csatv2.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
135135
x_Y = self.Y_Conv(x[:, 0, ])
136136
x_Cb = self.Cb_Conv(x[:, 1, ])
137137
x_Cr = self.Cr_Conv(x[:, 2, ])
138-
x = torch.cat([x_Y, x_Cb, x_Cr], axis=1)
138+
x = torch.cat([x_Y, x_Cb, x_Cr], dim=1)
139139
return x
140140

141141

@@ -160,12 +160,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
160160

161161
mean_list = torch.zeros([3, 64])
162162
var_list = torch.zeros([3, 64])
163-
mean_list[0] = torch.mean(x[:, 0, ], axis=0)
164-
mean_list[1] = torch.mean(x[:, 1, ], axis=0)
165-
mean_list[2] = torch.mean(x[:, 2, ], axis=0)
166-
var_list[0] = torch.var(x[:, 0, ], axis=0)
167-
var_list[1] = torch.var(x[:, 1, ], axis=0)
168-
var_list[2] = torch.var(x[:, 2, ], axis=0)
163+
mean_list[0] = torch.mean(x[:, 0, ], dim=0)
164+
mean_list[1] = torch.mean(x[:, 1, ], dim=0)
165+
mean_list[2] = torch.mean(x[:, 2, ], dim=0)
166+
var_list[0] = torch.var(x[:, 0, ], dim=0)
167+
var_list[1] = torch.var(x[:, 1, ], dim=0)
168+
var_list[2] = torch.var(x[:, 2, ], dim=0)
169169
return mean_list, var_list
170170

171171

@@ -470,6 +470,11 @@ def forward(self, x):
470470
x = self.forward_head(x)
471471
return x
472472

473+
474+
# --- Components like LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ ---
475+
# (이 부분은 einops와 무관하므로 위 코드와 동일하게 유지합니다. 여기서는 공간 절약을 위해 생략)
476+
# 기존 코드의 LayerNorm, GRN, DropPath, FeedForward, PosCNN, trunc_normal_ 함수를 그대로 사용하세요.
477+
473478
class LayerNorm(nn.Module):
474479
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
475480
super().__init__()
@@ -604,5 +609,4 @@ def csatv2(pretrained: bool = False, **kwargs) -> CSATv2:
604609
img_size=kwargs.pop('img_size', 512),
605610
num_classes=kwargs.pop('num_classes', 1000),
606611
)
607-
608612
return _create_csatv2('csatv2', pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)