Commit 8697d29 (1 parent: 1ff0b7d)

[intel] update driver.py and pytorch patches

Signed-off-by: Anatoly Myachev <[email protected]>
4 files changed: +95 −30 lines
python/tutorials/02-fused-softmax.py
Lines changed: 1 addition & 8 deletions

@@ -175,14 +175,7 @@ def allocated_slm_size(size_smem):
     num_programs = n_rows

     # Create a number of persistent programs.
-    kernel[(num_programs, 1, 1)](
-        y,
-        x,
-        x.stride(0),
-        y.stride(0),
-        n_rows,
-        n_cols,
-    )
+    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE)
     return y

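Note that the collapsed call now passes BLOCK_SIZE explicitly, since the tutorial kernel takes it as a tl.constexpr parameter and the updated launcher (see driver.py below) keeps constexpr arguments in the call site. For reference, a minimal sketch of the persistent softmax kernel this launch targets; the body is assumed from the upstream Triton fused-softmax tutorial, not taken from this commit:

import triton
import triton.language as tl

@triton.jit
def softmax_kernel(y_ptr, x_ptr, x_stride, y_stride, n_rows, n_cols,
                   BLOCK_SIZE: tl.constexpr):
    # Persistent scheme: each program loops over every num_programs-th row.
    row_start = tl.program_id(0)
    for row_idx in tl.range(row_start, n_rows, tl.num_programs(0)):
        col_offsets = tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        row = tl.load(x_ptr + row_idx * x_stride + col_offsets,
                      mask=mask, other=-float('inf'))
        numerator = tl.exp(row - tl.max(row, axis=0))
        tl.store(y_ptr + row_idx * y_stride + col_offsets,
                 numerator / tl.sum(numerator, axis=0), mask=mask)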

scripts/patch-pytorch.sh
Lines changed: 1 addition & 1 deletion

@@ -17,5 +17,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 cd "$REPO_ROOT"

 curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
-
+# build trigger
 git apply "${SCRIPT_DIR}/pytorch.patch"

scripts/pytorch.patch
Lines changed: 72 additions & 7 deletions

@@ -110,6 +110,24 @@ index 00031a56b8d..59086d41b40 100644
          self.triton_meta = triton_meta

          for tree in self.range_trees:
+diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
+index 8b8c29bbb15..c89a76e9868 100644
+--- a/torch/_inductor/codegen/triton_utils.py
++++ b/torch/_inductor/codegen/triton_utils.py
+@@ -165,12 +165,4 @@ def config_of(
+     else:
+         divisible_by_16 = ()
+
+-    equal_to_1 = tuple(
+-        i
+-        for i, arg in zip(indices, args)
+-        if isinstance(arg, SizeArg)
+-        and isinstance(arg.expr, (int, sympy.Integer))
+-        and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
+-    )
+-
+-    return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
++    return AttrsDescriptorWrapper(divisible_by_16)
 diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
 index 2ab2b326354..42d76b8bf94 100644
 --- a/torch/_inductor/codegen/wrapper.py
@@ -123,22 +141,39 @@ index 2ab2b326354..42d76b8bf94 100644
             "configs": [
                 config_of(
 diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
-index 276c01f3f42..5c633b7963b 100644
+index fa2a1334380..4d730fd45de 100644
 --- a/torch/_inductor/runtime/hints.py
 +++ b/torch/_inductor/runtime/hints.py
-@@ -48,8 +48,8 @@ if _is_triton_available():
+@@ -44,25 +44,14 @@ def _is_triton_available() -> bool:
+ # Define `AttrsDescriptorWrapper` function with clear conditional handling
+ if _is_triton_available():
+     try:
+-        from triton.backends.compiler import AttrsDescriptor
+
+         def AttrsDescriptorWrapper(
+             divisible_by_16=None,
+-            equal_to_1=None,
          ):
-             # Prepare the arguments for AttrsDescriptor
+-            # Prepare the arguments for AttrsDescriptor
              kwargs = {
 -                "tt.divisibility": divisible_by_16,
 -                "tt.equal_to": equal_to_1,
-+                "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
-+                "tt.equal_to": tuple([(i,) for i in equal_to_1]),
++                tuple([(i,) for i in divisible_by_16]): [["tt.divisibility", 16]],
              }
+-
+-            # Instantiate AttrsDescriptor with the prepared arguments
+-            res = AttrsDescriptor.from_dict(
+-                {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
+-            )
+-            assert res.property_values["tt.divisibility"] == 16
+-            assert res.property_values["tt.equal_to"] == 1
+-            return res
++            return kwargs

-             # Instantiate AttrsDescriptor with the prepared arguments
+     except ImportError:
+         from triton.compiler.compiler import AttrsDescriptor
 diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
-index 281d0e78ba4..901263df4aa 100644
+index 281d0e78ba4..3b059a365c9 100644
 --- a/torch/_inductor/runtime/triton_heuristics.py
 +++ b/torch/_inductor/runtime/triton_heuristics.py
 @@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
@@ -164,3 +199,33 @@ index 281d0e78ba4..901263df4aa 100644
                 compile_meta["constants"],
                 compile_meta["configs"][0],
             ),
+@@ -502,13 +513,11 @@ class CachingAutotuner(KernelInterface):
+         call_args = [
+             arg
+             for i, arg in enumerate(self.fn.arg_names)
+-            if i not in self.fn.constexprs and arg not in none_args
+         ]
+
+         def_args = [
+             name
+             for name in self.fn.arg_names
+-            if name not in cfg.kwargs and name not in none_args
+         ]
+         binary_shared = (
+             binary.shared if hasattr(binary, "shared") else binary.metadata.shared
+@@ -952,6 +961,7 @@ class CachingAutotuner(KernelInterface):
+         ):
+             return launcher(
+                 *args,
++                **launcher.config.kwargs,
+                 **kwargs,
+                 grid=grid,
+                 stream=stream,
+@@ -959,6 +969,7 @@ class CachingAutotuner(KernelInterface):
+         else:
+             return launcher(
+                 *args,
++                **launcher.config.kwargs,
+                 **kwargs,
+                 grid=grid,
+                 stream=stream,
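The net effect on AttrsDescriptorWrapper is easiest to see in isolation: the patched helper no longer imports or instantiates Triton's AttrsDescriptor, drops equal_to_1 handling entirely, and returns a plain dict keyed by argument-index tuples. A standalone sketch of the patched helper, with the dict shape taken directly from the patch above:

def AttrsDescriptorWrapper(divisible_by_16=None):
    # Keys are tuples of (arg_index,) tuples; values are property lists.
    kwargs = {
        tuple([(i,) for i in divisible_by_16]): [["tt.divisibility", 16]],
    }
    return kwargs

# For example, if arguments 0 and 2 are 16-byte-divisible:
print(AttrsDescriptorWrapper((0, 2)))
# -> {((0,), (2,)): [['tt.divisibility', 16]]}

The triton_heuristics.py hunks complement this: since constexpr arguments are no longer filtered out of call_args and def_args, the launcher forwards launcher.config.kwargs explicitly at call time.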

third_party/intel/backend/driver.py
Lines changed: 21 additions & 14 deletions

@@ -194,7 +194,7 @@ def wait(self):


 def ty_to_cpp(ty):
-    if ty[0] == '*' or ty == "none":
+    if ty[0] == '*':
         return "void*"
     return {
         "i1": "int32_t",
@@ -215,10 +215,12 @@ def ty_to_cpp(ty):
     }[ty]


-def make_launcher(constants, signature, ids):
+def make_launcher(constants, signature):

     def _extracted_type(ty):
-        if ty[0] == '*' or ty == "none":
+        if ty == "constexpr":
+            return "PyObject*"
+        if ty[0] == '*':
             return "PyObject*"
         if ty[0] == '[':
             if ty == "[]":
@@ -252,7 +254,6 @@ def format_of(ty):
         "uint64_t": "K",
     }[ty]

-    signature = {k: v for k, v in signature.items() if v != 'constexpr'}
     args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
     format = "iiiOOOOOO" + args_format
     signature = ','.join(signature.values()).replace('[', '').replace(']', '')
@@ -262,16 +263,22 @@ def format_of(ty):

     # Record the end of regular arguments;
     # subsequent arguments are architecture-specific descriptors.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
     internal_args_list = []
     for i, ty in signature.items():
-        if ty[0] == "*" or ty == "none":
+        if ty[0] == "*":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
-        else:
+        elif ty != "constexpr":
             internal_args_list.append(f"_arg{i}")

     # generate glue code
-    params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
+    newline = '\n  '
+    ptr_decls = [
+        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;"
+        for i, ty in signature.items()
+        if ty[0] == "*"
+    ]
+    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
     src = f"""
 #include <cstddef>
 #include <string>
@@ -394,7 +401,7 @@ def format_of(ty):
       assert(num_params == expected_num_params && "number of kernel param not matched");
       // Submit the imported kernel.
       auto cgf = [&](sycl::handler &cgh) {{
-        {" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
+        {" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if signature[i] != "constexpr"]))}
         if (shared_memory) {{
           using share_mem_t = sycl::local_accessor<int8_t, 1>;
           share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
@@ -418,7 +425,7 @@ def format_of(ty):
   PyObject *py_obj_stream;
   PyObject* py_kernel;

-  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
+  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
   if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
                        &kernel_metadata, &launch_metadata,
                        &launch_enter_hook, &launch_exit_hook {args_list})) {{
@@ -467,7 +474,7 @@ def format_of(ty):
   if(kernel_ptr == nullptr) return NULL;
   sycl::kernel kernel = *kernel_ptr;

-  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
+  {newline.join(ptr_decls)}
   sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});

   if(launch_exit_hook != Py_None){{
@@ -568,11 +575,11 @@ def serialize_args(args, constants, signature):
 class XPULauncher(object):

     def __init__(self, src, metadata):
-        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-        self.constants = {idx: value for idx, value in constants.items()}
+        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
+        self.constants = {arg_idx(idx): value for idx, value in constants.items()}
         self.signature = {idx: value for idx, value in src.signature.items()}
-        src = make_launcher(self.constants, self.signature, ids)
+        src = make_launcher(self.constants, self.signature)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch

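A common thread runs through these driver.py changes: the up-front filter that stripped constexpr entries from the signature is deleted, and each code-generation site now skips "constexpr" (and the old "none" special case) on its own. A small self-contained sketch of that filtering pattern, using a hypothetical signature dict for illustration:

# Hypothetical launcher signature: argument index -> Triton type string.
signature = {0: "*fp32", 1: "*fp32", 2: "i32", 3: "i32", 4: "constexpr"}

# constexpr values are baked into the compiled kernel, so every piece of
# generated C glue skips them at its own use site.
arg_decls = ', '.join(f"void* arg{i}" if ty[0] == '*' else f"int32_t arg{i}"
                      for i, ty in signature.items() if ty != "constexpr")
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]

print(arg_decls)  # void* arg0, void* arg1, int32_t arg2, int32_t arg3
print(params)     # ['&arg0', '&arg1', '&arg2', '&arg3']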
