@@ -142,6 +142,171 @@ def create_args_from_json(data):
142
142
return grid , args_dict
143
143
144
144
145
+ def _apply_stride_and_offset (tensor , shape , stride , storage_offset ):
146
+ """
147
+ Apply custom stride and storage offset to a tensor if needed.
148
+
149
+ Args:
150
+ tensor: The base contiguous tensor
151
+ shape: The desired shape
152
+ stride: The desired stride (or None for contiguous)
153
+ storage_offset: The desired storage offset
154
+
155
+ Returns:
156
+ torch.Tensor: The strided tensor view or original tensor if contiguous
157
+ """
158
+ if stride is None :
159
+ return tensor
160
+
161
+ # Calculate expected contiguous stride
162
+ expected_contiguous_stride = []
163
+ s = 1
164
+ for dim_size in reversed (shape ):
165
+ expected_contiguous_stride .insert (0 , s )
166
+ s *= dim_size
167
+
168
+ # If stride matches contiguous stride and no storage offset, return as-is
169
+ if tuple (stride ) == tuple (expected_contiguous_stride ) and storage_offset == 0 :
170
+ return tensor
171
+
172
+ # Calculate required storage size
173
+ if len (shape ) > 0 and len (stride ) > 0 :
174
+ max_offset = storage_offset
175
+ for dim_stride , dim_size in zip (stride , shape ):
176
+ if dim_size > 0 :
177
+ max_offset += dim_stride * (dim_size - 1 )
178
+ storage_size = max_offset + 1
179
+ else :
180
+ storage_size = storage_offset + 1
181
+
182
+ # Create larger storage tensor and create strided view
183
+ storage_tensor = torch .empty (storage_size , dtype = tensor .dtype , device = tensor .device )
184
+
185
+ # Create strided view
186
+ strided_view = storage_tensor .as_strided (
187
+ size = shape , stride = stride , storage_offset = storage_offset
188
+ )
189
+
190
+ # Copy data from the base tensor into the strided layout
191
+ strided_view .copy_ (tensor .flatten ()[: strided_view .numel ()].view (shape ))
192
+
193
+ return strided_view
194
+
195
+
196
+ def _create_base_tensor (arg_info ) -> torch .Tensor :
197
+ if arg_info .get ("blob_path" ):
198
+ return load_tensor (arg_info .get ("blob_path" ), arg_info .get ("device" ))
199
+
200
+ # Extract basic tensor properties
201
+ dtype_str = arg_info .get ("dtype" )
202
+ try :
203
+ torch_dtype = getattr (torch , dtype_str .split ("." )[- 1 ])
204
+ except AttributeError :
205
+ logging .error (f"Unsupported dtype: { dtype_str } . Defaulting to float32." )
206
+ torch_dtype = torch .float32
207
+
208
+ shape = arg_info .get ("shape" , [])
209
+ device = arg_info .get ("device" , "cpu" )
210
+
211
+ # Extract statistical information if available
212
+ mean = arg_info .get ("mean" )
213
+ std = arg_info .get ("std" )
214
+ min_val = arg_info .get ("min" )
215
+ max_val = arg_info .get ("max" )
216
+ has_stats = (
217
+ mean is not None
218
+ and std is not None
219
+ and min_val is not None
220
+ and max_val is not None
221
+ )
222
+
223
+ if arg_info .get ("tensor_capture_error" , False ):
224
+ logging .error (
225
+ f"Error: Tensor '{ arg_info .get ('name' , '' )} ' had capture error. Generating random tensor instead."
226
+ )
227
+
228
+ # Use a dummy tensor to check properties of the dtype
229
+ tensor_props = torch .empty (0 , dtype = torch_dtype )
230
+
231
+ # Case 1: Floating point types
232
+ if tensor_props .is_floating_point ():
233
+ if has_stats :
234
+ # Generate tensor with statistical properties matching original data
235
+ if std == 0 or min_val == max_val :
236
+ # Constant tensor
237
+ return torch .full (shape , mean , dtype = torch_dtype , device = device )
238
+ # Generate normal distribution with mean and std, then clamp to [min, max]
239
+ tensor = torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
240
+ tensor = torch .clamp (tensor , min = min_val , max = max_val )
241
+ return tensor .to (torch_dtype )
242
+ else :
243
+ # Fallback to original random generation
244
+ if torch_dtype in [torch .float8_e4m3fn , torch .float8_e5m2 ]:
245
+ tmp = torch .rand (shape , dtype = torch .float32 , device = device )
246
+ return tmp .to (torch_dtype )
247
+ else :
248
+ return torch .empty (shape , dtype = torch_dtype , device = device ).random_ ()
249
+
250
+ # Case 2: Integer types
251
+ elif torch_dtype in [
252
+ torch .int8 ,
253
+ torch .int16 ,
254
+ torch .int32 ,
255
+ torch .int64 ,
256
+ torch .uint8 ,
257
+ torch .bool ,
258
+ ]:
259
+ if has_stats and torch_dtype != torch .bool :
260
+ # Generate tensor with statistical properties, then round for integers
261
+ if std == 0 or min_val == max_val :
262
+ # Constant tensor
263
+ return torch .full (shape , int (mean ), dtype = torch_dtype , device = device )
264
+ tensor = torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
265
+ tensor = torch .clamp (tensor , min = min_val , max = max_val )
266
+ return torch .round (tensor ).to (torch_dtype )
267
+ else :
268
+ # Fallback to original random generation
269
+ return torch .empty (shape , dtype = torch_dtype , device = device ).random_ ()
270
+
271
+ # Case 3: Complex numbers need special handling
272
+ elif tensor_props .is_complex ():
273
+ # Complex types: fallback to original logic for now
274
+ # TODO: Could be improved to use statistical info if available
275
+ float_dtype = torch .float32 if torch_dtype == torch .complex64 else torch .float64
276
+ real_part = torch .rand (shape , dtype = float_dtype , device = device )
277
+ imag_part = torch .rand (shape , dtype = float_dtype , device = device )
278
+ return torch .complex (real_part , imag_part )
279
+
280
+ # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
281
+ elif "uint" in str (torch_dtype ):
282
+ if has_stats :
283
+ # Generate tensor with statistical properties for unsigned integers
284
+ if std == 0 or min_val == max_val :
285
+ return torch .full (shape , int (mean ), dtype = torch_dtype , device = device )
286
+ tensor = torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
287
+ tensor = torch .clamp (tensor , min = min_val , max = max_val )
288
+ return torch .round (tensor ).to (torch_dtype )
289
+ else :
290
+ # Fallback to original random generation
291
+ return torch .randint (0 , 1000 , shape , dtype = torch_dtype , device = device )
292
+
293
+ # Case 5: If we don't know how to handle the type, raise an error
294
+ else :
295
+ raise NotImplementedError (
296
+ f"Random data generation not implemented for dtype: { torch_dtype } "
297
+ )
298
+
299
+
300
def _create_tensor(arg_info) -> torch.Tensor:
    """Construct a tensor argument, honoring any recorded stride/offset layout."""
    base = _create_base_tensor(arg_info)

    # Re-layout the contiguous base tensor when the capture recorded a
    # non-default stride or storage offset.
    return _apply_stride_and_offset(
        base,
        arg_info.get("shape", []),
        arg_info.get("stride"),
        arg_info.get("storage_offset", 0),
    )
308
+
309
+
145
310
def _create_arg_from_info (arg_info ):
146
311
"""
147
312
Recursively construct a kernel argument from its JSON schema.
@@ -166,121 +331,7 @@ def _create_arg_from_info(arg_info):
166
331
return arg_info .get ("value" )
167
332
168
333
elif arg_type == "tensor" :
169
- if arg_info .get ("blob_path" ):
170
- return load_tensor (arg_info .get ("blob_path" ), arg_info .get ("device" ))
171
-
172
- # Extract basic tensor properties
173
- dtype_str = arg_info .get ("dtype" )
174
- try :
175
- torch_dtype = getattr (torch , dtype_str .split ("." )[- 1 ])
176
- except AttributeError :
177
- logging .error (f"Unsupported dtype: { dtype_str } . Defaulting to float32." )
178
- torch_dtype = torch .float32
179
-
180
- shape = arg_info .get ("shape" , [])
181
- device = arg_info .get ("device" , "cpu" )
182
-
183
- # Extract statistical information if available
184
- mean = arg_info .get ("mean" )
185
- std = arg_info .get ("std" )
186
- min_val = arg_info .get ("min" )
187
- max_val = arg_info .get ("max" )
188
- has_stats = (
189
- mean is not None
190
- and std is not None
191
- and min_val is not None
192
- and max_val is not None
193
- )
194
-
195
- if arg_info .get ("tensor_capture_error" , False ):
196
- logging .error (
197
- f"Error: Tensor '{ arg_info .get ('name' , '' )} ' had capture error. Generating random tensor instead."
198
- )
199
-
200
- # Use a dummy tensor to check properties of the dtype
201
- tensor_props = torch .empty (0 , dtype = torch_dtype )
202
-
203
- # Case 1: Floating point types
204
- if tensor_props .is_floating_point ():
205
- if has_stats :
206
- # Generate tensor with statistical properties matching original data
207
- if std == 0 or min_val == max_val :
208
- # Constant tensor
209
- return torch .full (shape , mean , dtype = torch_dtype , device = device )
210
- # Generate normal distribution with mean and std, then clamp to [min, max]
211
- tensor = (
212
- torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
213
- )
214
- tensor = torch .clamp (tensor , min = min_val , max = max_val )
215
- return tensor .to (torch_dtype )
216
- else :
217
- # Fallback to original random generation
218
- if torch_dtype in [torch .float8_e4m3fn , torch .float8_e5m2 ]:
219
- tmp = torch .rand (shape , dtype = torch .float32 , device = device )
220
- return tmp .to (torch_dtype )
221
- else :
222
- return torch .empty (
223
- shape , dtype = torch_dtype , device = device
224
- ).random_ ()
225
-
226
- # Case 2: Integer types
227
- elif torch_dtype in [
228
- torch .int8 ,
229
- torch .int16 ,
230
- torch .int32 ,
231
- torch .int64 ,
232
- torch .uint8 ,
233
- torch .bool ,
234
- ]:
235
- if has_stats and torch_dtype != torch .bool :
236
- # Generate tensor with statistical properties, then round for integers
237
- if std == 0 or min_val == max_val :
238
- # Constant tensor
239
- return torch .full (
240
- shape , int (mean ), dtype = torch_dtype , device = device
241
- )
242
- tensor = (
243
- torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
244
- )
245
- tensor = torch .clamp (tensor , min = min_val , max = max_val )
246
- return torch .round (tensor ).to (torch_dtype )
247
- else :
248
- # Fallback to original random generation
249
- return torch .empty (shape , dtype = torch_dtype , device = device ).random_ ()
250
-
251
- # Case 3: Complex numbers need special handling
252
- elif tensor_props .is_complex ():
253
- # Complex types: fallback to original logic for now
254
- # TODO: Could be improved to use statistical info if available
255
- float_dtype = (
256
- torch .float32 if torch_dtype == torch .complex64 else torch .float64
257
- )
258
- real_part = torch .rand (shape , dtype = float_dtype , device = device )
259
- imag_part = torch .rand (shape , dtype = float_dtype , device = device )
260
- return torch .complex (real_part , imag_part )
261
-
262
- # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
263
- elif "uint" in str (torch_dtype ):
264
- if has_stats :
265
- # Generate tensor with statistical properties for unsigned integers
266
- if std == 0 or min_val == max_val :
267
- return torch .full (
268
- shape , int (mean ), dtype = torch_dtype , device = device
269
- )
270
- tensor = (
271
- torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
272
- )
273
- tensor = torch .clamp (tensor , min = min_val , max = max_val )
274
- return torch .round (tensor ).to (torch_dtype )
275
- else :
276
- # Fallback to original random generation
277
- return torch .randint (0 , 1000 , shape , dtype = torch_dtype , device = device )
278
-
279
- # Case 5: If we don't know how to handle the type, raise an error
280
- else :
281
- raise NotImplementedError (
282
- f"Random data generation not implemented for dtype: { torch_dtype } "
283
- )
334
+ return _create_tensor (arg_info )
284
335
285
336
elif arg_type == "triton_kernels.tensor.Tensor" :
286
337
if not TRITON_KERNELS_CUSTOM_TYPES :
0 commit comments