33
44import ttnn
55from models .common .lightweightmodule import LightweightModule
6- import torch
76from models .demos .deepseek_v3 .utils .config_helpers import matmul_config
87
98
@@ -27,6 +26,11 @@ def __init__(
2726 self .norm_weight = parameters ["norm" ]["weight" ]
2827 self .norm_bias = parameters ["norm" ]["bias" ]
2928
29+ self .kernel_top_left = parameters ["conv_kernels" ]["top_left" ]
30+ self .kernel_bottom_left = parameters ["conv_kernels" ]["bottom_left" ]
31+ self .kernel_top_right = parameters ["conv_kernels" ]["top_right" ]
32+ self .kernel_bottom_right = parameters ["conv_kernels" ]["bottom_right" ]
33+
3034 def forward (self , input_tensor ):
3135 """
3236 Args:
@@ -44,96 +48,29 @@ def forward(self, input_tensor):
4448 input_tensor = ttnn .reshape (input_tensor , (B , H , W , C ))
4549 x = ttnn .to_layout (input_tensor , ttnn .ROW_MAJOR_LAYOUT , memory_config = ttnn .L1_MEMORY_CONFIG )
4650
47- kernel_top_left = torch .zeros (C , 1 , 2 , 2 , dtype = torch .bfloat16 )
48- kernel_top_left [:, 0 , 0 , 0 ] = 1.0
49-
50- kernel_bottom_left = torch .zeros (C , 1 , 2 , 2 , dtype = torch .bfloat16 )
51- kernel_bottom_left [:, 0 , 1 , 0 ] = 1.0
52-
53- kernel_top_right = torch .zeros (C , 1 , 2 , 2 , dtype = torch .bfloat16 )
54- kernel_top_right [:, 0 , 0 , 1 ] = 1.0
55-
56- kernel_bottom_right = torch .zeros (C , 1 , 2 , 2 , dtype = torch .bfloat16 )
57- kernel_bottom_right [:, 0 , 1 , 1 ] = 1.0
58-
59- # Convert to TTNN tensors
60- tt_kernel_top_left = ttnn .from_torch (kernel_top_left , device = self .device )
61- tt_kernel_bottom_left = ttnn .from_torch (kernel_bottom_left , device = self .device )
62- tt_kernel_top_right = ttnn .from_torch (kernel_top_right , device = self .device )
63- tt_kernel_bottom_right = ttnn .from_torch (kernel_bottom_right , device = self .device )
64-
65- # Apply grouped convolutions for each patch
66- x0 = ttnn .conv2d (
67- input_tensor = x ,
68- weight_tensor = tt_kernel_top_left ,
69- in_channels = C ,
70- out_channels = C ,
71- device = self .device ,
72- kernel_size = (2 , 2 ),
73- stride = (2 , 2 ),
74- padding = (0 , 0 ),
75- groups = C , # Grouped convolution
76- batch_size = B ,
77- input_height = H ,
78- input_width = W ,
79- conv_config = None ,
80- dtype = ttnn .bfloat16 ,
81- memory_config = ttnn .DRAM_MEMORY_CONFIG ,
82- )
83-
84- x1 = ttnn .conv2d (
85- input_tensor = x ,
86- weight_tensor = tt_kernel_bottom_left ,
87- in_channels = C ,
88- out_channels = C ,
89- device = self .device ,
90- kernel_size = (2 , 2 ),
91- stride = (2 , 2 ),
92- padding = (0 , 0 ),
93- groups = C ,
94- batch_size = B ,
95- input_height = H ,
96- input_width = W ,
97- conv_config = None ,
98- dtype = ttnn .bfloat16 ,
99- memory_config = ttnn .DRAM_MEMORY_CONFIG ,
100- )
101-
102- x2 = ttnn .conv2d (
103- input_tensor = x ,
104- weight_tensor = tt_kernel_top_right ,
105- in_channels = C ,
106- out_channels = C ,
107- device = self .device ,
108- kernel_size = (2 , 2 ),
109- stride = (2 , 2 ),
110- padding = (0 , 0 ),
111- groups = C ,
112- batch_size = B ,
113- input_height = H ,
114- input_width = W ,
115- conv_config = None ,
116- dtype = ttnn .bfloat16 ,
117- memory_config = ttnn .DRAM_MEMORY_CONFIG ,
118- )
119-
120- x3 = ttnn .conv2d (
121- input_tensor = x ,
122- weight_tensor = tt_kernel_bottom_right ,
123- in_channels = C ,
124- out_channels = C ,
125- device = self .device ,
126- kernel_size = (2 , 2 ),
127- stride = (2 , 2 ),
128- padding = (0 , 0 ),
129- groups = C ,
130- batch_size = B ,
131- input_height = H ,
132- input_width = W ,
133- conv_config = None ,
134- dtype = ttnn .bfloat16 ,
135- memory_config = ttnn .DRAM_MEMORY_CONFIG ,
136- )
51+ # Common convolution parameters
52+ conv_params = {
53+ "input_tensor" : x ,
54+ "in_channels" : C ,
55+ "out_channels" : C ,
56+ "device" : self .device ,
57+ "kernel_size" : (2 , 2 ),
58+ "stride" : (2 , 2 ),
59+ "padding" : (0 , 0 ),
60+ "groups" : C , # Grouped convolution
61+ "batch_size" : B ,
62+ "input_height" : H ,
63+ "input_width" : W ,
64+ "conv_config" : None ,
65+ "dtype" : ttnn .bfloat16 ,
66+ "memory_config" : ttnn .DRAM_MEMORY_CONFIG ,
67+ }
68+
69+ # Apply grouped convolutions for each patch; this replaces a slice operation
70+ x0 = ttnn .conv2d (weight_tensor = self .kernel_top_left , ** conv_params )
71+ x1 = ttnn .conv2d (weight_tensor = self .kernel_bottom_left , ** conv_params )
72+ x2 = ttnn .conv2d (weight_tensor = self .kernel_top_right , ** conv_params )
73+ x3 = ttnn .conv2d (weight_tensor = self .kernel_bottom_right , ** conv_params )
13774
13875 ttnn .deallocate (x )
13976
0 commit comments