Skip to content

Commit 12f1d1b

Browse files
wychimeta-codesync[bot]
authored and committed
Fix non-contiguous tensor recreation in TritonParse reproducer
Summary: The TritonParse reproducer was generating contiguous tensors when recreating kernel arguments from JSON, but the original tensors had non-contiguous strides and storage offsets. This caused CUDA out-of-bounds memory access errors when kernels computed pointer offsets using the stride information. For example, the K tensor had shape [995415, 4, 64] with stride [1024, 64, 1] and storage_offset 768, but the reproducer created a contiguous tensor with stride [256, 64, 1] and no offset. When the kernel computed addresses like `K + off_h * stride_kh + seq_start * stride_kn`, it would access invalid memory. This change adds proper handling for non-contiguous tensors by: 1. Extracting stride and storage_offset from JSON metadata 2. Creating a properly sized storage buffer 3. Using as_strided() to create views with the correct memory layout The logic is refactored into helper functions to maintain clean code structure: - _apply_stride_and_offset(): Handles stride/offset application - _create_base_tensor(): Creates the base contiguous tensor with data - _create_tensor(): Orchestrates the full tensor creation pipeline Reviewed By: FindHao Differential Revision: D84100085 fbshipit-source-id: c9618cd797bfe65c3ebeeff03ec054be56103d2f
1 parent 9fc4edd commit 12f1d1b

File tree

1 file changed

+166
-115
lines changed

1 file changed

+166
-115
lines changed

tritonparse/reproducer/templates/example.py

Lines changed: 166 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,171 @@ def create_args_from_json(data):
142142
return grid, args_dict
143143

144144

145+
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
146+
"""
147+
Apply custom stride and storage offset to a tensor if needed.
148+
149+
Args:
150+
tensor: The base contiguous tensor
151+
shape: The desired shape
152+
stride: The desired stride (or None for contiguous)
153+
storage_offset: The desired storage offset
154+
155+
Returns:
156+
torch.Tensor: The strided tensor view or original tensor if contiguous
157+
"""
158+
if stride is None:
159+
return tensor
160+
161+
# Calculate expected contiguous stride
162+
expected_contiguous_stride = []
163+
s = 1
164+
for dim_size in reversed(shape):
165+
expected_contiguous_stride.insert(0, s)
166+
s *= dim_size
167+
168+
# If stride matches contiguous stride and no storage offset, return as-is
169+
if tuple(stride) == tuple(expected_contiguous_stride) and storage_offset == 0:
170+
return tensor
171+
172+
# Calculate required storage size
173+
if len(shape) > 0 and len(stride) > 0:
174+
max_offset = storage_offset
175+
for dim_stride, dim_size in zip(stride, shape):
176+
if dim_size > 0:
177+
max_offset += dim_stride * (dim_size - 1)
178+
storage_size = max_offset + 1
179+
else:
180+
storage_size = storage_offset + 1
181+
182+
# Create larger storage tensor and create strided view
183+
storage_tensor = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device)
184+
185+
# Create strided view
186+
strided_view = storage_tensor.as_strided(
187+
size=shape, stride=stride, storage_offset=storage_offset
188+
)
189+
190+
# Copy data from the base tensor into the strided layout
191+
strided_view.copy_(tensor.flatten()[: strided_view.numel()].view(shape))
192+
193+
return strided_view
194+
195+
196+
def _create_base_tensor(arg_info) -> torch.Tensor:
197+
if arg_info.get("blob_path"):
198+
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
199+
200+
# Extract basic tensor properties
201+
dtype_str = arg_info.get("dtype")
202+
try:
203+
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
204+
except AttributeError:
205+
logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
206+
torch_dtype = torch.float32
207+
208+
shape = arg_info.get("shape", [])
209+
device = arg_info.get("device", "cpu")
210+
211+
# Extract statistical information if available
212+
mean = arg_info.get("mean")
213+
std = arg_info.get("std")
214+
min_val = arg_info.get("min")
215+
max_val = arg_info.get("max")
216+
has_stats = (
217+
mean is not None
218+
and std is not None
219+
and min_val is not None
220+
and max_val is not None
221+
)
222+
223+
if arg_info.get("tensor_capture_error", False):
224+
logging.error(
225+
f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
226+
)
227+
228+
# Use a dummy tensor to check properties of the dtype
229+
tensor_props = torch.empty(0, dtype=torch_dtype)
230+
231+
# Case 1: Floating point types
232+
if tensor_props.is_floating_point():
233+
if has_stats:
234+
# Generate tensor with statistical properties matching original data
235+
if std == 0 or min_val == max_val:
236+
# Constant tensor
237+
return torch.full(shape, mean, dtype=torch_dtype, device=device)
238+
# Generate normal distribution with mean and std, then clamp to [min, max]
239+
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
240+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
241+
return tensor.to(torch_dtype)
242+
else:
243+
# Fallback to original random generation
244+
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
245+
tmp = torch.rand(shape, dtype=torch.float32, device=device)
246+
return tmp.to(torch_dtype)
247+
else:
248+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
249+
250+
# Case 2: Integer types
251+
elif torch_dtype in [
252+
torch.int8,
253+
torch.int16,
254+
torch.int32,
255+
torch.int64,
256+
torch.uint8,
257+
torch.bool,
258+
]:
259+
if has_stats and torch_dtype != torch.bool:
260+
# Generate tensor with statistical properties, then round for integers
261+
if std == 0 or min_val == max_val:
262+
# Constant tensor
263+
return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
264+
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
265+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
266+
return torch.round(tensor).to(torch_dtype)
267+
else:
268+
# Fallback to original random generation
269+
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
270+
271+
# Case 3: Complex numbers need special handling
272+
elif tensor_props.is_complex():
273+
# Complex types: fallback to original logic for now
274+
# TODO: Could be improved to use statistical info if available
275+
float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
276+
real_part = torch.rand(shape, dtype=float_dtype, device=device)
277+
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
278+
return torch.complex(real_part, imag_part)
279+
280+
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
281+
elif "uint" in str(torch_dtype):
282+
if has_stats:
283+
# Generate tensor with statistical properties for unsigned integers
284+
if std == 0 or min_val == max_val:
285+
return torch.full(shape, int(mean), dtype=torch_dtype, device=device)
286+
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
287+
tensor = torch.clamp(tensor, min=min_val, max=max_val)
288+
return torch.round(tensor).to(torch_dtype)
289+
else:
290+
# Fallback to original random generation
291+
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
292+
293+
# Case 5: If we don't know how to handle the type, raise an error
294+
else:
295+
raise NotImplementedError(
296+
f"Random data generation not implemented for dtype: {torch_dtype}"
297+
)
298+
299+
300+
def _create_tensor(arg_info) -> torch.Tensor:
    """Create a tensor for ``arg_info``, restoring its original memory layout.

    Builds the base contiguous tensor first, then re-applies the captured
    stride/storage_offset so non-contiguous inputs are reproduced faithfully.
    """
    base = _create_base_tensor(arg_info)
    return _apply_stride_and_offset(
        base,
        arg_info.get("shape", []),
        arg_info.get("stride"),
        arg_info.get("storage_offset", 0),
    )
308+
309+
145310
def _create_arg_from_info(arg_info):
146311
"""
147312
Recursively construct a kernel argument from its JSON schema.
@@ -166,121 +331,7 @@ def _create_arg_from_info(arg_info):
166331
return arg_info.get("value")
167332

168333
elif arg_type == "tensor":
169-
if arg_info.get("blob_path"):
170-
return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))
171-
172-
# Extract basic tensor properties
173-
dtype_str = arg_info.get("dtype")
174-
try:
175-
torch_dtype = getattr(torch, dtype_str.split(".")[-1])
176-
except AttributeError:
177-
logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
178-
torch_dtype = torch.float32
179-
180-
shape = arg_info.get("shape", [])
181-
device = arg_info.get("device", "cpu")
182-
183-
# Extract statistical information if available
184-
mean = arg_info.get("mean")
185-
std = arg_info.get("std")
186-
min_val = arg_info.get("min")
187-
max_val = arg_info.get("max")
188-
has_stats = (
189-
mean is not None
190-
and std is not None
191-
and min_val is not None
192-
and max_val is not None
193-
)
194-
195-
if arg_info.get("tensor_capture_error", False):
196-
logging.error(
197-
f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
198-
)
199-
200-
# Use a dummy tensor to check properties of the dtype
201-
tensor_props = torch.empty(0, dtype=torch_dtype)
202-
203-
# Case 1: Floating point types
204-
if tensor_props.is_floating_point():
205-
if has_stats:
206-
# Generate tensor with statistical properties matching original data
207-
if std == 0 or min_val == max_val:
208-
# Constant tensor
209-
return torch.full(shape, mean, dtype=torch_dtype, device=device)
210-
# Generate normal distribution with mean and std, then clamp to [min, max]
211-
tensor = (
212-
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
213-
)
214-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
215-
return tensor.to(torch_dtype)
216-
else:
217-
# Fallback to original random generation
218-
if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
219-
tmp = torch.rand(shape, dtype=torch.float32, device=device)
220-
return tmp.to(torch_dtype)
221-
else:
222-
return torch.empty(
223-
shape, dtype=torch_dtype, device=device
224-
).random_()
225-
226-
# Case 2: Integer types
227-
elif torch_dtype in [
228-
torch.int8,
229-
torch.int16,
230-
torch.int32,
231-
torch.int64,
232-
torch.uint8,
233-
torch.bool,
234-
]:
235-
if has_stats and torch_dtype != torch.bool:
236-
# Generate tensor with statistical properties, then round for integers
237-
if std == 0 or min_val == max_val:
238-
# Constant tensor
239-
return torch.full(
240-
shape, int(mean), dtype=torch_dtype, device=device
241-
)
242-
tensor = (
243-
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
244-
)
245-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
246-
return torch.round(tensor).to(torch_dtype)
247-
else:
248-
# Fallback to original random generation
249-
return torch.empty(shape, dtype=torch_dtype, device=device).random_()
250-
251-
# Case 3: Complex numbers need special handling
252-
elif tensor_props.is_complex():
253-
# Complex types: fallback to original logic for now
254-
# TODO: Could be improved to use statistical info if available
255-
float_dtype = (
256-
torch.float32 if torch_dtype == torch.complex64 else torch.float64
257-
)
258-
real_part = torch.rand(shape, dtype=float_dtype, device=device)
259-
imag_part = torch.rand(shape, dtype=float_dtype, device=device)
260-
return torch.complex(real_part, imag_part)
261-
262-
# Case 4: Handle other unsigned integers (like uint32) which fail with random_()
263-
elif "uint" in str(torch_dtype):
264-
if has_stats:
265-
# Generate tensor with statistical properties for unsigned integers
266-
if std == 0 or min_val == max_val:
267-
return torch.full(
268-
shape, int(mean), dtype=torch_dtype, device=device
269-
)
270-
tensor = (
271-
torch.randn(shape, dtype=torch.float32, device=device) * std + mean
272-
)
273-
tensor = torch.clamp(tensor, min=min_val, max=max_val)
274-
return torch.round(tensor).to(torch_dtype)
275-
else:
276-
# Fallback to original random generation
277-
return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)
278-
279-
# Case 5: If we don't know how to handle the type, raise an error
280-
else:
281-
raise NotImplementedError(
282-
f"Random data generation not implemented for dtype: {torch_dtype}"
283-
)
334+
return _create_tensor(arg_info)
284335

285336
elif arg_type == "triton_kernels.tensor.Tensor":
286337
if not TRITON_KERNELS_CUSTOM_TYPES:

0 commit comments

Comments
 (0)