Skip to content

Commit 2fa5756

Browse files
ign-amanksign-krishnanand
authored andcommitted
Refactors TS
1 parent 6958b5c commit 2fa5756

21 files changed

+72
-37
lines changed

models/experimental/SSR/tests/test_mlp.py renamed to models/experimental/SSR/tests/common/test_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from loguru import logger
77

88
from models.experimental.SSR.reference.SSR.model.net_blocks import Mlp
9-
from models.experimental.SSR.tt import TTMlp
9+
from models.experimental.SSR.tt.common import TTMlp
1010

1111
from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight
1212
from models.utility_functions import (

models/experimental/SSR/tests/test_basic_block.py renamed to models/experimental/SSR/tests/tile_selection/test_basic_block.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,16 @@
66
from loguru import logger
77

88
import ttnn
9-
from models.experimental.SSR.tt import TTBasicLayer
10-
from models.experimental.SSR.tt.patch_merging import TTPatchMerging
9+
from models.experimental.SSR.tt.tile_selection import TTBasicLayer, TTPatchMerging
1110
from ttnn.model_preprocessing import preprocess_model_parameters
1211
from models.utility_functions import tt2torch_tensor, comp_pcc
1312

14-
from models.experimental.SSR.tests.test_swin_transformer_block import create_swin_transformer_block_preprocessor
15-
from models.experimental.SSR.tests.test_patch_merging import create_patch_merging_preprocessor
13+
from models.experimental.SSR.tests.tile_selection.test_swin_transformer_block import (
14+
create_swin_transformer_block_preprocessor,
15+
)
16+
from models.experimental.SSR.tests.tile_selection.test_patch_merging import create_patch_merging_preprocessor
1617
from models.experimental.SSR.reference.SSR.model.net_blocks import PatchMerging, BasicLayer
1718

18-
import collections
19-
20-
21-
def to_2tuple(x):
22-
if isinstance(x, collections.abc.Iterable):
23-
return x
24-
return (x, x)
25-
2619

2720
def create_basic_layer_preprocessor(device, dim):
2821
def custom_preprocessor(torch_model, name, ttnn_module_args):

models/experimental/SSR/tests/test_mask_token_inference.py renamed to models/experimental/SSR/tests/tile_selection/test_mask_token_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from loguru import logger
55
from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight
66
from models.experimental.SSR.reference.SSR.model.tile_selection import mask_token_inference
7-
from models.experimental.SSR.tt.mask_token_inference import TTMaskTokenInference
7+
from models.experimental.SSR.tt.tile_selection import TTMaskTokenInference
88

99
from models.utility_functions import tt2torch_tensor, comp_pcc
1010

models/experimental/SSR/tests/test_patch_embed.py renamed to models/experimental/SSR/tests/tile_selection/test_patch_embed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Fix the imports based on the codebase patterns
77
from models.experimental.SSR.reference.SSR.model.net_blocks import PatchEmbed
8-
from models.experimental.SSR.tt.patch_embed import TTPatchEmbed # Updated path
8+
from models.experimental.SSR.tt.tile_selection import TTPatchEmbed # Updated path
99
from ttnn.model_preprocessing import preprocess_model_parameters
1010
from models.utility_functions import tt2torch_tensor, comp_pcc
1111

models/experimental/SSR/tests/test_patch_merging.py renamed to models/experimental/SSR/tests/tile_selection/test_patch_merging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from models.experimental.SSR.reference.SSR.model.net_blocks import PatchMerging
1010
from ttnn.model_preprocessing import preprocess_model_parameters
1111
from models.utility_functions import tt2torch_tensor, comp_pcc
12-
from models.experimental.SSR.tt.patch_merging import TTPatchMerging
12+
from models.experimental.SSR.tt.tile_selection import TTPatchMerging
1313

1414

1515
def create_patch_merging_preprocessor(device, dim):

models/experimental/SSR/tests/test_swin_transformer_block.py renamed to models/experimental/SSR/tests/tile_selection/test_swin_transformer_block.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from loguru import logger
55

66
from models.experimental.SSR.reference.SSR.model.net_blocks import SwinTransformerBlock
7-
from models.experimental.SSR.tt import TTSwinTransformerBlock
8-
from models.experimental.SSR.tests.test_mlp import create_mlp_preprocessor
9-
from models.experimental.SSR.tests.test_window_attn import create_window_attention_preprocessor
7+
from models.experimental.SSR.tt.tile_selection import TTSwinTransformerBlock
8+
from models.experimental.SSR.tests.common.test_mlp import create_mlp_preprocessor
9+
from models.experimental.SSR.tests.tile_selection.test_window_attn import create_window_attention_preprocessor
1010
from ttnn.model_preprocessing import preprocess_model_parameters
1111

1212
from models.utility_functions import (
@@ -71,6 +71,7 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
7171
(3, 8, 8, 96, 3, 7, 3, 4.0),
7272
),
7373
)
74+
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}])
7475
def test_swin_transformer_block(device, batch_size, height, width, dim, num_heads, window_size, shift_size, mlp_ratio):
7576
# Create input tensor
7677
input_shape = (batch_size, height * width, dim)

models/experimental/SSR/tests/test_tile_selection.py renamed to models/experimental/SSR/tests/tile_selection/test_tile_selection.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight
77
from models.experimental.SSR.reference.SSR.model.tile_selection import TileSelection
88
from models.experimental.SSR.tt.tile_selection import TTTileSelection
9-
from models.experimental.SSR.tests.test_patch_embed import create_patch_embed_preprocessor
10-
from models.experimental.SSR.tests.test_basic_block import create_basic_layer_preprocessor
11-
from models.experimental.SSR.tests.test_mlp import create_mlp_preprocessor
12-
from models.experimental.SSR.tests.test_mask_token_inference import create_mask_token_inference_preprocessor
9+
from models.experimental.SSR.tests.tile_selection.test_patch_embed import create_patch_embed_preprocessor
10+
from models.experimental.SSR.tests.tile_selection.test_basic_block import create_basic_layer_preprocessor
11+
from models.experimental.SSR.tests.common.test_mlp import create_mlp_preprocessor
12+
from models.experimental.SSR.tests.tile_selection.test_mask_token_inference import (
13+
create_mask_token_inference_preprocessor,
14+
)
1315
from models.utility_functions import tt2torch_tensor, comp_pcc
1416

1517

models/experimental/SSR/tests/test_window_attn.py renamed to models/experimental/SSR/tests/tile_selection/test_window_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from models.experimental.SSR.reference.SSR.model.net_blocks import WindowAttention
99
from timm.models.layers import to_2tuple
1010

11-
from models.experimental.SSR.tt import TTWindowAttention
11+
from models.experimental.SSR.tt.tile_selection import TTWindowAttention
1212
from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight
1313
from models.utility_functions import (
1414
tt2torch_tensor,

models/experimental/SSR/tt/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from .mlp import TTMlp
2-
from .window_attn import TTWindowAttention
3-
from .swin_transformer_block import TTSwinTransformerBlock
4-
from .patch_embed import TTPatchEmbed
5-
from .patch_merging import TTPatchMerging
6-
from .basic_block import TTBasicLayer
7-
from .mask_token_inference import TTMaskTokenInference
1+
from .common import TTMlp
2+
from .tile_selection import (
3+
TTWindowAttention,
4+
TTSwinTransformerBlock,
5+
TTPatchEmbed,
6+
TTPatchMerging,
7+
TTBasicLayer,
8+
TTMaskTokenInference,
9+
)
810
from .patch_unembed import TTPatchUnEmbed
911
from .window_attn_tr import TTWindowAttentionTR
1012

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .mlp import TTMlp
2+
3+
__all__ = ["TTMlp"]

0 commit comments

Comments
 (0)