File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -171,7 +171,10 @@ def forward_orig(
171171 pe = None
172172
173173 blocks_replace = patches_replace .get ("dit" , {})
174+ transformer_options ["total_blocks" ] = len (self .double_blocks )
175+ transformer_options ["block_type" ] = "double"
174176 for i , block in enumerate (self .double_blocks ):
177+ transformer_options ["block_index" ] = i
175178 if ("double_block" , i ) in blocks_replace :
176179 def block_wrap (args ):
177180 out = {}
@@ -215,7 +218,10 @@ def block_wrap(args):
215218 if self .params .global_modulation :
216219 vec , _ = self .single_stream_modulation (vec_orig )
217220
221+ transformer_options ["total_blocks" ] = len (self .single_blocks )
222+ transformer_options ["block_type" ] = "single"
218223 for i , block in enumerate (self .single_blocks ):
224+ transformer_options ["block_index" ] = i
219225 if ("single_block" , i ) in blocks_replace :
220226 def block_wrap (args ):
221227 out = {}
You can’t perform that action at this time.
0 commit comments