@@ -84,9 +84,45 @@ def _(state: CodegenState) -> ast.AST:
     )


+@_decorators.ref(store)
+def _(
+    tensor: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | torch.SymInt | float,
+    extra_mask: torch.Tensor | None = None,
+) -> None:
+    # Convert index list to tuple for tensor indexing
+    index_tuple = tuple(index)
+
+    # Apply extra mask if provided
+    if extra_mask is not None:
+        # Only store where the mask is True
+        if isinstance(value, torch.Tensor):
+            tensor[index_tuple] = torch.where(extra_mask, value, tensor[index_tuple])  # pyright: ignore[reportArgumentType]
+        else:
+            # For scalar values, we need to create a tensor of the right shape
+            current = tensor[index_tuple]  # pyright: ignore[reportArgumentType]
+            # Cast value to a proper numeric type for full_like
+            if isinstance(value, torch.SymInt):
+                numeric_value = int(value)
+            else:
+                numeric_value = value
+            tensor[index_tuple] = torch.where(  # pyright: ignore[reportArgumentType]
+                extra_mask, torch.full_like(current, numeric_value), current
+            )
+    else:
+        # Handle SymInt case for assignment
+        if isinstance(value, torch.SymInt):
+            tensor[index_tuple] = int(value)  # pyright: ignore[reportArgumentType]
+        else:
+            tensor[index_tuple] = value  # pyright: ignore[reportArgumentType]
+
+
 @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def load(
-    tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None
+    tensor: torch.Tensor,
+    index: list[object],
+    extra_mask: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Load a value from a tensor using a list of indices.

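Note (not part of the diff): a minimal, self-contained sketch of the masked-store behavior the `store` ref above implements, where positions with `extra_mask` set to False keep their previous contents. All names below are illustrative only.

import torch

dst = torch.zeros(4)
mask = torch.tensor([True, False, True, False])
vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
idx = (torch.arange(4),)

# Mirrors `torch.where(extra_mask, value, tensor[index_tuple])` above:
# masked-off elements of dst are written back unchanged.
dst[idx] = torch.where(mask, vals, dst[idx])
print(dst)  # tensor([1., 0., 3., 0.])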
@@ -129,6 +165,83 @@ def _(node: torch.fx.Node) -> int:
     return 0  # loads are always masked to 0


+@_decorators.ref(load)
+def _(
+    tensor: torch.Tensor,
+    index: list[object],
+    extra_mask: torch.Tensor | None = None,
+) -> torch.Tensor:
+    from .ref_tile import RefTile
+
+    if extra_mask is None:
+        return tensor[tuple(index)]  # pyright: ignore[reportArgumentType]
+
+    # Create zero result matching mask shape
+    result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Process indices: convert RefTiles and clamp tensor indices
+    orig_indices, safe_indices, is_tensor_mask = [], [], []
+    for i, idx in enumerate(index):
+        if isinstance(idx, RefTile):
+            idx = idx.index  # Convert RefTile to tensor
+
+        if isinstance(idx, torch.Tensor):
+            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
+            orig_indices.append(idx)
+            safe_indices.append(torch.clamp(idx, 0, dim_size - 1))
+            is_tensor_mask.append(True)
+        else:
+            orig_indices.append(idx)
+            safe_indices.append(idx)
+            is_tensor_mask.append(False)
+
+    # Apply broadcasting if we have multiple tensor indices
+    tensor_positions = [i for i, is_tensor in enumerate(is_tensor_mask) if is_tensor]
+
+    if len(tensor_positions) > 1:
+        # Add unsqueeze operations for broadcasting
+        broadcast_indices = []
+        for i, (idx, is_tensor) in enumerate(
+            zip(safe_indices, is_tensor_mask, strict=False)
+        ):
+            if is_tensor:
+                new_idx = idx
+                # Add dimension for each other tensor index
+                for j, other_pos in enumerate(tensor_positions):
+                    if other_pos != i:
+                        new_idx = new_idx.unsqueeze(j if other_pos < i else -1)
+                broadcast_indices.append(new_idx)
+            else:
+                broadcast_indices.append(idx)
+        values = tensor[tuple(broadcast_indices)]
+    else:
+        values = tensor[tuple(safe_indices)]
+
+    # Build validity mask
+    valid_mask = extra_mask.clone()
+    for i, (orig_idx, is_tensor) in enumerate(
+        zip(orig_indices, is_tensor_mask, strict=False)
+    ):
+        if is_tensor:
+            dim_size = tensor.shape[i] if i < len(tensor.shape) else tensor.numel()
+            in_bounds = (orig_idx >= 0) & (orig_idx < dim_size)
+            # Broadcast to match mask shape by adding dimensions
+            # Count how many tensor indices come before and after this one
+            n_before = sum(1 for j in range(i) if is_tensor_mask[j])
+            n_after = sum(
+                1 for j in range(i + 1, len(is_tensor_mask)) if is_tensor_mask[j]
+            )
+
+            # Add dimensions: n_after dimensions at the end, n_before at the beginning
+            for _ in range(n_after):
+                in_bounds = in_bounds.unsqueeze(-1)
+            for _ in range(n_before):
+                in_bounds = in_bounds.unsqueeze(0)
+            valid_mask = valid_mask & in_bounds
+
+    return torch.where(valid_mask, values, result)
+
+
 @has_side_effect
 @_decorators.api(allow_host_tensor=True)
 def atomic_add(
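Note (not part of the diff): a small sketch of the clamp-then-mask pattern the `load` ref uses for out-of-bounds tensor indices. Indices are clamped so the gather itself is safe, then masked out of the result along with anything `extra_mask` disables. Illustrative only.

import torch

src = torch.tensor([10.0, 20.0, 30.0])
idx = torch.tensor([0, 2, 5])            # 5 is out of bounds
extra_mask = torch.tensor([True, True, True])

safe = torch.clamp(idx, 0, src.numel() - 1)   # safe to gather with
vals = src[safe]
in_bounds = (idx >= 0) & (idx < src.numel())
out = torch.where(extra_mask & in_bounds, vals, torch.zeros_like(vals))
print(out)  # tensor([10., 30., 0.]) -- the out-of-bounds position reads as 0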
@@ -210,6 +323,59 @@ def _(
     return None


+@_decorators.ref(atomic_add)
+def _(
+    target: torch.Tensor,
+    index: list[object],
+    value: torch.Tensor | float,
+    sem: str = "relaxed",
+) -> None:
+    """Reference implementation of atomic_add for interpret mode."""
+    from .. import exc
+    from .ref_tile import RefTile
+
+    # Validate sem parameter
+    if sem not in ["relaxed", "acquire", "release", "acq_rel"]:
+        raise exc.InternalError(
+            ValueError(
+                f"Invalid memory semantic '{sem}'. Valid options are: relaxed, acquire, release, acq_rel"
+            )
+        )
+
+    # Convert indices to proper format
+    processed_index = []
+    for idx in index:
+        if isinstance(idx, RefTile):
+            processed_index.append(idx._slice)
+        elif isinstance(idx, torch.Tensor) and idx.numel() == 1:
+            processed_index.append(int(idx.item()))
+        else:
+            processed_index.append(idx)
+
+    # Find tensor indices that need element-wise processing
+    tensor_indices = [
+        (i, idx)
+        for i, idx in enumerate(processed_index)
+        if isinstance(idx, torch.Tensor) and idx.numel() > 1
+    ]
+
+    if tensor_indices:
+        # Element-wise processing for tensor indices
+        i, tensor_idx = tensor_indices[0]  # Handle first tensor index
+        for j, elem in enumerate(tensor_idx):
+            new_index = processed_index.copy()
+            new_index[i] = int(elem.item())
+            val = (
+                value[j]
+                if isinstance(value, torch.Tensor) and value.numel() > 1
+                else value
+            )
+            target[tuple(new_index)] += val
+    else:
+        # Direct atomic add
+        target[tuple(processed_index)] += value
+
+
 @_decorators.codegen(atomic_add)
 def _(state: CodegenState) -> ast.AST:
     target = state.proxy_arg(0)
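Note (not part of the diff): a sketch of why the `atomic_add` ref loops per element when given a multi-element tensor index. With duplicate indices, advanced-indexing `+=` applies only one of the colliding updates, while the per-element loop accumulates all of them, matching atomic-add semantics. Illustrative only.

import torch

t_fused = torch.zeros(3)
t_loop = torch.zeros(3)
idx = torch.tensor([1, 1, 2])
val = torch.tensor([1.0, 1.0, 1.0])

t_fused[idx] += val                     # duplicates collapse: [0., 1., 1.]
for j, elem in enumerate(idx):          # per-element, as in the ref above
    t_loop[int(elem.item())] += val[j]  # duplicates accumulate: [0., 2., 1.]
print(t_fused, t_loop)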