Skip to content

Commit 6566f6c

Browse files
committed
Merge branch 'main' into etiotto/coalesce_for_block_ptr
2 parents 547d6fa + 0f002cd commit 6566f6c

File tree

70 files changed

+1514
-1245
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+1514
-1245
lines changed

.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ runs:
8282
uses: ./.github/actions/load
8383
env:
8484
# Increase this value to reset cache
85-
CACHE_NUMBER: 11
85+
CACHE_NUMBER: 12
8686
with:
8787
path: pytorch
8888
key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
8321eec009c8c79145ebccd51fdfc336e5f8b848
1+
487873f7cafeb0fd390eaefe40496b804bceabbd

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ name: Integration Tests
1010
on:
1111
workflow_dispatch:
1212
pull_request:
13-
# You can name your branch dev-foo to get CI runs.
14-
branches: [main, 'dev-**']
13+
branches-ignore: ['llvm-**']
1514
merge_group:
1615
branches: [main, 'dev-**']
1716
types: [checks_requested]

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ name: Integration Tests
99
on:
1010
workflow_dispatch:
1111
pull_request:
12-
# You can name your branch dev-foo to get CI runs.
13-
branches: [main, 'dev-**']
12+
branches-ignore: ['llvm-**']
1413
merge_group:
1514
branches: [main, 'dev-**']
1615
types: [checks_requested]

.github/workflows/llvm-build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ jobs:
157157
cp -r /usr/aarch64-linux-gnu/lib ./arm-sysroot
158158
cp -r /usr/aarch64-linux-gnu/include ./arm-sysroot
159159
LINKER=$(pwd)/arm-sysroot/lib/ld-linux-aarch64.so.1
160-
wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb
161-
dpkg-deb -x gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb ./arm-sysroot
160+
wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.2.0-1_amd64.deb
161+
dpkg-deb -x gcc-aarch64-linux-gnu_14.2.0-1_amd64.deb ./arm-sysroot
162162
export LD_LIBRARY_PATH=$(pwd)/arm-sysroot/lib:$LD_LIBRARY_PATH
163163
sudo ln -s $LINKER /lib/ld-linux-aarch64.so.1
164164
SYSROOT="$(pwd)/arm-sysroot"

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ python/triton/language/extra
3131
# Proton
3232
python/triton/profiler
3333

34+
# Instrumentation
35+
python/triton/instrumentation
36+
3437
# Python caches
3538
__pycache__/
3639
*.py[cod]

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,19 +399,77 @@ def format_of(ty):
399399
return src
400400

401401

402+
def serialize_kernel_metadata(arg, args_dict):
403+
args_dict["num_warps"] = arg.num_warps
404+
args_dict["threads_per_warp"] = arg.threads_per_warp
405+
args_dict["shared_memory"] = arg.shared
406+
args_dict["kernel_name"] = arg.name
407+
args_dict["spv_name"] = f"{arg.name}.spv"
408+
409+
410+
def serialize_args(args, constants, signature):
411+
import numbers
412+
dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
413+
if not os.path.exists(dir_path):
414+
os.makedirs(dir_path)
415+
print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
416+
417+
cnt = 0
418+
args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
419+
args_dict["argument_list"] = []
420+
counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
421+
cnt = 4
422+
for arg in args[cnt:]:
423+
if type(arg).__name__ == "KernelMetadata":
424+
serialize_kernel_metadata(arg, args_dict)
425+
426+
if isinstance(arg, torch.Tensor):
427+
cpu_tensor = arg.cpu()
428+
tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
429+
with open(tensor_path, "wb") as f:
430+
torch.save(cpu_tensor, f)
431+
new_arg = {
432+
"name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
433+
signature[counts["karg_cnt"]]
434+
}
435+
args_dict["argument_list"].append(new_arg)
436+
counts["karg_cnt"] += 1
437+
counts["tensors"] += 1
438+
439+
if isinstance(arg, numbers.Number):
440+
if counts["karg_cnt"] not in constants:
441+
new_arg = {
442+
"name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
443+
signature[counts["karg_cnt"]]
444+
}
445+
args_dict["argument_list"].append(new_arg)
446+
counts["karg_cnt"] += 1
447+
counts["scalars"] += 1
448+
cnt += 1
449+
# Dump argument info as a JSON file
450+
json_path = os.path.join(dir_path, "args_data.json")
451+
with open(json_path, "w", encoding="utf-8") as json_file:
452+
import json
453+
json.dump(args_dict, json_file, indent=4)
454+
455+
402456
class XPULauncher:
403457

404458
def __init__(self, src, metadata): # pylint: disable=unused-argument
405459
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
406460
constants = src.constants if hasattr(src, "constants") else {}
407461
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
408-
constants = {cst_key(key): value for key, value in constants.items()}
409-
signature = {cst_key(key): value for key, value in src.signature.items()}
410-
src = make_launcher(constants, signature, ids)
462+
self.constants = {cst_key(key): value for key, value in constants.items()}
463+
self.signature = {cst_key(key): value for key, value in src.signature.items()}
464+
src = make_launcher(self.constants, self.signature, ids)
411465
mod = compile_module_from_src(src, "__triton_launcher")
412466
self.launch = mod.launch
413467

414468
def __call__(self, *args, **kwargs):
469+
# Serialize KernelArguments for SPIR-V Runner
470+
serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
471+
if serialize_kernel_args:
472+
serialize_args(args, self.constants, self.signature)
415473
self.launch(*args, **kwargs)
416474

417475

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def forward(q, k, v, causal, sm_scale):
171171
assert Lk in {16, 32, 64, 128}
172172
o = torch.empty_like(q, dtype=torch.float32)
173173
BLOCK_M = 128
174-
BLOCK_N = 64 if Lk <= 64 else 32
174+
BLOCK_N = 64
175175
num_stages = 3
176176
num_warps = 8 if Lq == 64 else 16
177177
stage = 3 if causal else 1
@@ -205,7 +205,8 @@ def forward(q, k, v, causal, sm_scale):
205205
BLOCK_DMODEL=Lk, #
206206
STAGE=stage, #
207207
num_warps=num_warps, #
208-
num_stages=num_stages #
208+
num_stages=num_stages, #
209+
grf_mode='large', #
209210
)
210211
return o
211212

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8888
mlir::registerTritonAMDGPUStreamPipeline();
8989
mlir::registerTritonAMDGPUStreamPipelineV2();
9090
mlir::registerTritonAMDGPUCanonicalizePointers();
91-
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
92-
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
9391

9492
// TODO: register Triton & TritonGPU passes
9593
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
61f8a7f618901797ee8663389a29722f29216a96
1+
b5cc222d7429fe6f18c787f633d5262fac2e676f

0 commit comments

Comments
 (0)