Skip to content

Commit 08d7f64

Browse files
authored
[INTERPRETER] Improve support for aliased tensors in interpreter (#5890)
If two tensors alias, the interpreter will clobber writes when it makes a CPU copy of its backing storage. Identifying this in general is quite difficult, but in most cases aliasing tensors share the same backing storage. If we set up our CPU-side storage to mirror the device, then we can avoid this particular bug in that case. This partially fixes #5791.

<!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. -->

# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a comment.
- [ ] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent b2a86b1 commit 08d7f64

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

python/test/unit/language/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7025,3 +7025,15 @@ def _simple_add(
70257025
_simple_add[grid](x, x.stride(0), x.stride(1))
70267026

70277027
assert torch.allclose(x, torch.ones_like(x) * c_dim)
7028+
7029+
7030+
@pytest.mark.interpreter
def test_aliasing(device):
    """Regression test for aliased kernel arguments under the interpreter.

    The same tensor is passed for both kernel parameters. A store through one
    alias must remain visible on the device afterwards: per the commit message,
    the interpreter previously clobbered such writes when it copied each
    argument's backing storage to the CPU independently (issue #5791).
    """

    @triton.jit
    def aliasing_kernel(buffer, buffer2):
        # Only the first alias is written; buffer2 exists solely to pass
        # the same underlying tensor in twice.
        triton.language.store(buffer, 1)

    # One tensor, handed to the kernel under two parameter names.
    aliased = torch.zeros(1, device=device)
    aliasing_kernel[(1, )](aliased, aliased)
    assert aliased[0] == 1

python/triton/runtime/interpreter.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,16 +1157,22 @@ def __init__(self, fn, arg_names, grid):
11571157
self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
11581158

11591159
def _init_args_hst(self, args_dev, kwargs):
1160+
storages = {}
11601161

11611162
def _to_cpu(arg):
11621163
if isinstance(arg, tuple):
11631164
return _tuple_create(arg, map(_to_cpu, arg))
11641165
elif not hasattr(arg, "data_ptr"):
11651166
return arg
1167+
11661168
unwrapped_arg = _unwrap_tensor(arg)
1169+
if unwrapped_arg.untyped_storage().data_ptr() not in storages:
1170+
storage = unwrapped_arg.untyped_storage()
1171+
storages[storage.data_ptr()] = storage.cpu()
1172+
1173+
storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
11671174
cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
1168-
cpu_arg.set_(unwrapped_arg.untyped_storage().cpu(), unwrapped_arg.storage_offset(), unwrapped_arg.size(),
1169-
unwrapped_arg.stride())
1175+
cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
11701176
cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
11711177
return cpu_arg
11721178

@@ -1175,21 +1181,17 @@ def _to_cpu(arg):
11751181
# Process keyword arguments
11761182
kwargs_hst = {}
11771183
for key, value in kwargs.items():
1178-
if hasattr(value, "data_ptr"):
1179-
kwargs_hst[key] = value.cpu()
1180-
elif isinstance(value, tuple):
1181-
return _tuple_create(value, map(_to_cpu, value))
1182-
else:
1183-
kwargs_hst[key] = value
1184+
kwargs_hst[key] = _to_cpu(value)
11841185
return args_hst, kwargs_hst
11851186

11861187
def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
1188+
storages = {}
11871189

11881190
def _from_cpu(arg_dev, arg_hst):
11891191
if hasattr(arg_dev, "data_ptr"):
11901192
# No need to rewrap because this just modifies internal
11911193
arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
1192-
arg_dev.untyped_storage().copy_(arg_hst.untyped_storage())
1194+
storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
11931195
elif isinstance(arg_dev, tuple):
11941196
for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
11951197
_from_cpu(arg_dev, arg_hst)
@@ -1202,6 +1204,9 @@ def _from_cpu(arg_dev, arg_hst):
12021204
kwarg_hst = kwargs_hst[key]
12031205
_from_cpu(kwarg_dev, kwarg_hst)
12041206

1207+
for (arg_dev, arg_hst) in storages.values():
1208+
arg_dev.copy_(arg_hst)
1209+
12051210
def __call__(self, *args_dev, **kwargs):
12061211
if kwargs.pop("warmup", False):
12071212
return

0 commit comments

Comments
 (0)