Commit 8dd8510

[cherry-pick][inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack (pytorch#158646)
* [inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack (pytorch#158459)

  Update the Triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton.

  Also, update the HAS_WARP_SPEC handling: in Triton 3.4, warp specialization has a different interface, in which num_consumer_groups is determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

  Pull Request resolved: pytorch#158459
  Approved by: https://github.com/atalman

* dont_upde_hash

* Revert "dont_upde_hash"

  This reverts commit 5fffb12.

* fix_docker_builds

--------

Co-authored-by: David Berard <[email protected]>
1 parent 40e7433 commit 8dd8510
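
The signature check described above can be illustrated in isolation. Below is a minimal sketch of the inspect.signature introspection pattern the commit relies on; accepts_param and make_config are hypothetical stand-ins for illustration, not Triton or PyTorch APIs.

import inspect


def accepts_param(func, name: str) -> bool:
    # True if `func` declares a parameter called `name`.
    return name in inspect.signature(func).parameters


# Stand-in for triton.Config.__init__ on a hypothetical build that accepts
# explicit consumer groups (purely for illustration).
def make_config(kwargs, num_warps=4, num_stages=2, num_consumer_groups=0):
    return (kwargs, num_warps, num_stages, num_consumer_groups)


assert accepts_param(make_config, "num_consumer_groups")
assert not accepts_param(make_config, "num_buffers_warp_spec")

Detecting the parameter directly makes the flag track what the installed build actually accepts, rather than inferring it from a version string.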

4 files changed: 15 additions & 33 deletions

.ci/docker/ci_commit_pins/triton.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-ae848267bebc65c6181e8cc5e64a6357d2679260
+11ec6354315768a85da41032535e3b7b99c5f706

.ci/docker/common/install_conda.sh

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
   export SYSROOT_DEP="sysroot_linux-64=2.17"
 fi
 
+# Please see: https://github.com/pytorch/pytorch/pull/158370#issuecomment-3084705725
+export CONDA_PLUGINS_AUTO_ACCEPT_TOS="yes"
 # Install correct Python version
 # Also ensure sysroot is using a modern GLIBC to match system compilers
 as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\

test/inductor/test_static_cuda_launcher.py

Lines changed: 0 additions & 31 deletions
@@ -14,7 +14,6 @@
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.triton_utils import requires_cuda
-from torch.torch_version import TorchVersion
 
 
 @requires_cuda
@@ -140,36 +139,6 @@ def signed_integers(
         launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50)
         self.assertEqual(new_arg0, arg0)
 
-    # TODO: floats don't work properly, triton seems to think they're all tl.float32
-    # despite type annotations.
-    # There's also not really a good way for me to make a float16 in python...
-    @skipIfRocm
-    def test_floats(self):
-        @triton.jit
-        def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64):
-            x = tl.load(arg0)
-            y = arg1 + arg2 + arg3
-            tl.store(arg0, x + y)
-
-        arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
-
-        args = (arg0, 1.0, 1.0, 1.0)
-
-        compiled_kernel = floats[1,](*args)
-        launcher = self._make_launcher(compiled_kernel)
-        if TorchVersion(triton.__version__) >= TorchVersion("3.4.0"):
-            self.assertEqual(launcher.arg_tys, "Offd")
-        else:
-            self.assertEqual(launcher.arg_tys, "Offf")
-        # TODO this line fails on Triton 3.4.0 (https://github.com/triton-lang/triton/issues/6176)
-        # Add the check back when this is fixed in Triton
-        # self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda"))
-        new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
-        device_interface = get_interface_for_device("cuda")
-        stream = device_interface.get_raw_stream(device_interface.current_device())
-        launcher.run(1, 1, 1, stream, new_arg0, 1.0, 1.0, 1.0)
-        self.assertEqual(new_arg0, arg0)
-
     @skipIfRocm
     def test_basic_1arg(self):
         @triton.jit

torch/_inductor/runtime/triton_compat.py

Lines changed: 12 additions & 1 deletion
@@ -69,7 +69,18 @@ def GPUTarget(
     def _log2(x: Any) -> Any:
         raise NotImplementedError
 
-HAS_WARP_SPEC = hasattr(tl, "async_task")
+def _triton_config_has(param_name: str) -> bool:
+    if not hasattr(triton, "Config"):
+        return False
+    if not hasattr(triton.Config, "__init__"):
+        return False
+    return param_name in inspect.signature(triton.Config.__init__).parameters
+
+HAS_WARP_SPEC = (
+    hasattr(tl, "async_task")
+    and _triton_config_has("num_consumer_groups")
+    and _triton_config_has("num_buffers_warp_spec")
+)
 
 try:
     from triton import knobs

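For context on how such a flag is typically consumed: the following is a hypothetical downstream sketch, not code from this commit, that gates the warp-spec keyword arguments on HAS_WARP_SPEC instead of a Triton version check. It assumes Triton is installed; the meta-parameter BLOCK_M and the group/buffer counts are invented values.

import triton

from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC

# Pass the warp-spec parameters only when this Triton build accepts them;
# otherwise build a plain config and let the compiler manage warp groups.
extra = {"num_consumer_groups": 2, "num_buffers_warp_spec": 3} if HAS_WARP_SPEC else {}
config = triton.Config({"BLOCK_M": 64}, num_warps=8, num_stages=3, **extra)
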