@@ -46,8 +46,6 @@ def __init__(self, embedding_dim, hidden_dim, prio_accuracy=False, context=None)
         super().__init__(context=context)
 
     def set_up_artifacts(self):
-        # Artifact setup
-        # ---
         artifacts = []
         device_str = self.context.device_manager.device_str()
 
@@ -57,6 +55,7 @@ def set_up_artifacts(self):
             num_aie_columns=8,
             tile_size=1,
         )
+        self.gemv_1 = gemv_1
         gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts(
             prefix="swiglu_decode_gemv_1_"
         )
@@ -75,6 +74,8 @@ def set_up_artifacts(self):
             num_channels=2,
             tile_size=self.hidden_dim // 16,
         )
+        self.silu = silu
+        self.hidden_dim_padded = silu.size
         silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_decode_silu_")
         silu_xclbin.xclbin_input = gemv_1_xclbin
         silu_xclbin.extra_flags += [
@@ -91,6 +92,8 @@ def set_up_artifacts(self):
             num_channels=2,
             tile_size=self.hidden_dim // 8,
         )
+        self.eltwise_mul = eltwise_mul
+        assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded
         eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts(
             prefix="swiglu_decode_eltwise_mul_"
         )
@@ -109,6 +112,7 @@ def set_up_artifacts(self):
             num_aie_columns=8,
             tile_size=1,
         )
+        self.gemv_2 = gemv_2
         gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts(
             prefix="swiglu_decode_gemv_2_"
         )
@@ -135,28 +139,26 @@ def set_up_artifacts(self):
         self.add_artifacts(artifacts)
 
     def set_up_runtime(self):
-        # Runtime setup
-        # ---
         self.add_buffer("input", self.embedding_dim)
         self.add_buffer(
             "weights_1",
-            self.embedding_dim * self.hidden_dim,
+            self.embedding_dim * self.hidden_dim_padded,
             static_data=torch_to_numpy(self.weights_1),
         )
         self.add_buffer(
             "weights_2",
-            self.embedding_dim * self.hidden_dim,
+            self.embedding_dim * self.hidden_dim_padded,
             static_data=torch_to_numpy(self.weights_2),
         )
         self.add_buffer(
             "weights_3",
-            self.hidden_dim * self.embedding_dim,
+            self.hidden_dim_padded * self.embedding_dim,
             static_data=torch_to_numpy(self.weights_3),
         )
-        self.add_buffer("left", self.hidden_dim)
-        self.add_buffer("left_swished", self.hidden_dim)
-        self.add_buffer("right", self.hidden_dim)
-        self.add_buffer("intermediate", self.hidden_dim)
+        self.add_buffer("left", self.hidden_dim_padded)
+        self.add_buffer("left_swished", self.hidden_dim_padded)
+        self.add_buffer("right", self.hidden_dim_padded)
+        self.add_buffer("intermediate", self.hidden_dim_padded)
         self.add_buffer("output", self.embedding_dim)
         self.add_kernel(
             "swiglu_gemv_1",
@@ -191,9 +193,7 @@ def set_up_runtime(self):
         self.add_to_runlist("swiglu_gemv_2", "weights_3", "intermediate", "output")
 
     def forward(self, x):
-        # Turn into a numpy vector and drop the batch and other higher dimensions, if any; will error if batch or other higher dimensions > 1
         x_flat = x.reshape(x.shape[-1])
-
         assert x_flat.shape[0] == self.embedding_dim
 
         self.write_buffer("input", x_flat)
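
The buffer-sizing changes above key every weight and intermediate buffer off hidden_dim_padded, which the commit takes from silu.size and cross-checks against eltwise_mul.size. A minimal sketch of that relationship follows, assuming the kernels round their logical length up to a multiple of a tile granularity; pad_to_multiple and the concrete numbers are illustrative assumptions, not part of the kernel API.

# Illustrative sketch only: assumes each kernel pads its logical length up to
# a multiple of its tile granularity, which is why the weight and intermediate
# buffers are sized with hidden_dim_padded instead of hidden_dim.
def pad_to_multiple(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return -(-n // multiple) * multiple

hidden_dim = 11008                                         # hypothetical logical size
hidden_dim_padded = pad_to_multiple(hidden_dim, 8 * 192)   # hypothetical granularity -> 12288

# Mirrors the invariant asserted in the commit: the padded size never shrinks
# below the logical hidden dimension.
assert hidden_dim <= hidden_dim_padded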