@@ -102,6 +102,25 @@ def codegen_store(
102
102
) -> ast .AST :
103
103
indexing = SubscriptIndexing .create (state , fake_tensor , subscript , extra_mask )
104
104
name = state .device_function .tensor_arg (fake_tensor ).name
105
+
106
+ # Check if value is a tensor load (Name node with id matching a tensor arg)
107
+ if isinstance (value , ast .Name ) and hasattr (state .device_function , '_tensor_args' ):
108
+ # Check if this name corresponds to a tensor argument
109
+ for tensor , tensor_arg in state .device_function ._tensor_args .items ():
110
+ if tensor_arg .name == value .id :
111
+ # This is a tensor value, we need to load from it
112
+ # Get the shape of the slice we're storing to
113
+ output_shape = SubscriptIndexing .compute_shape (fake_tensor , subscript )
114
+ if len (output_shape ) == 1 and tensor .ndim == 1 :
115
+ # Load the entire 1D tensor
116
+ value_indexing = SubscriptIndexing .create (state , tensor , [slice (None )], None )
117
+ value = expr_from_string (
118
+ f"tl.load({ value .id } + offset, mask)" ,
119
+ offset = value_indexing .index_expr ,
120
+ mask = value_indexing .mask_expr ,
121
+ )
122
+ break
123
+
105
124
return expr_from_string (
106
125
f"tl.store({ name } + offset, value, mask)" ,
107
126
value = value ,
@@ -511,7 +530,14 @@ def compute_shape(
511
530
output_size .extend (k .size ())
512
531
else :
513
532
raise exc .InvalidIndexingType (k )
514
- assert len (input_size ) == 0 , "invalid subscript"
533
+ # For partial indexing, append remaining dimensions to output
534
+ while input_size :
535
+ size = input_size .popleft ()
536
+ if size != 1 :
537
+ rdim = env .allocate_reduction_dimension (size )
538
+ output_size .append (rdim .var )
539
+ else :
540
+ output_size .append (1 )
515
541
return output_size
516
542
517
543
@staticmethod
@@ -648,6 +674,22 @@ def create(
648
674
)
649
675
else :
650
676
raise exc .InvalidIndexingType (type (k ))
677
+
678
+ # Handle remaining dimensions for partial indexing
679
+ while len (index_values ) < fake_value .ndim :
680
+ expand = tile_strategy .expand_str (output_size , output_idx )
681
+ size = fake_value .size (len (index_values ))
682
+ if size != 1 :
683
+ rdim = env .allocate_reduction_dimension (size )
684
+ block_idx = rdim .block_id
685
+ index_var = state .codegen .index_var (block_idx )
686
+ index_values .append (f"({ index_var } ){ expand } " )
687
+ if mask := state .codegen .mask_var (block_idx ):
688
+ mask_values .setdefault (f"({ mask } ){ expand } " )
689
+ else :
690
+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
691
+ output_idx += 1
692
+
651
693
assert len (output_size ) == output_idx
652
694
assert len (index_values ) == fake_value .ndim
653
695
index_expr = []
0 commit comments