Skip to content

Commit cf6b36e

Browse files
committed
try to fix E2E
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 50b77e2 commit cf6b36e

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

scripts/patch-pytorch.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
1717
cd "$REPO_ROOT"
1818

1919
curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
20-
2120
git apply "${SCRIPT_DIR}/pytorch.patch"

scripts/pytorch.patch

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,22 +123,40 @@ index 2ab2b326354..42d76b8bf94 100644
123123
"configs": [
124124
config_of(
125125
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
126-
index 276c01f3f42..5c633b7963b 100644
126+
index fa2a1334380..11de4dc6e5f 100644
127127
--- a/torch/_inductor/runtime/hints.py
128128
+++ b/torch/_inductor/runtime/hints.py
129-
@@ -48,8 +48,8 @@ if _is_triton_available():
129+
@@ -44,25 +44,16 @@ def _is_triton_available() -> bool:
130+
# Define `AttrsDescriptorWrapper` function with clear conditional handling
131+
if _is_triton_available():
132+
try:
133+
- from triton.backends.compiler import AttrsDescriptor
134+
135+
def AttrsDescriptorWrapper(
136+
divisible_by_16=None,
137+
equal_to_1=None,
130138
):
131-
# Prepare the arguments for AttrsDescriptor
139+
- # Prepare the arguments for AttrsDescriptor
132140
kwargs = {
133141
- "tt.divisibility": divisible_by_16,
134142
- "tt.equal_to": equal_to_1,
135-
+ "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
136-
+ "tt.equal_to": tuple([(i,) for i in equal_to_1]),
143+
+ tuple([(i,) for i in divisible_by_16]): [["tt.divisibility", 16]],
144+
+ tuple([(i,) for i in equal_to_1]): [["tt.equal_to", 1]],
137145
}
146+
-
147+
- # Instantiate AttrsDescriptor with the prepared arguments
148+
- res = AttrsDescriptor.from_dict(
149+
- {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
150+
- )
151+
- assert res.property_values["tt.divisibility"] == 16
152+
- assert res.property_values["tt.equal_to"] == 1
153+
- return res
154+
+ return kwargs
138155

139-
# Instantiate AttrsDescriptor with the prepared arguments
156+
except ImportError:
157+
from triton.compiler.compiler import AttrsDescriptor
140158
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
141-
index 281d0e78ba4..901263df4aa 100644
159+
index 281d0e78ba4..302700dfe81 100644
142160
--- a/torch/_inductor/runtime/triton_heuristics.py
143161
+++ b/torch/_inductor/runtime/triton_heuristics.py
144162
@@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
@@ -164,3 +182,12 @@ index 281d0e78ba4..901263df4aa 100644
164182
compile_meta["constants"],
165183
compile_meta["configs"][0],
166184
),
185+
@@ -502,7 +513,7 @@ class CachingAutotuner(KernelInterface):
186+
call_args = [
187+
arg
188+
for i, arg in enumerate(self.fn.arg_names)
189+
- if i not in self.fn.constexprs and arg not in none_args
190+
+ if arg not in none_args
191+
]
192+
193+
def_args = [

0 commit comments

Comments
 (0)