Skip to content

Commit 2512ab6

Browse files
authored
Merge branch 'main' into sub-group-shuffle-broadcast
2 parents 70e7eb2 + 61fd54d commit 2512ab6

File tree

61 files changed

+1466
-1216
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+1466
-1216
lines changed

.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ runs:
8282
uses: ./.github/actions/load
8383
env:
8484
# Increase this value to reset cache
85-
CACHE_NUMBER: 11
85+
CACHE_NUMBER: 12
8686
with:
8787
path: pytorch
8888
key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
8321eec009c8c79145ebccd51fdfc336e5f8b848
1+
487873f7cafeb0fd390eaefe40496b804bceabbd

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ name: Integration Tests
1010
on:
1111
workflow_dispatch:
1212
pull_request:
13-
# You can name your branch dev-foo to get CI runs.
14-
branches: [main, 'dev-**']
13+
branches-ignore: ['llvm-**']
1514
merge_group:
1615
branches: [main, 'dev-**']
1716
types: [checks_requested]

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ name: Integration Tests
99
on:
1010
workflow_dispatch:
1111
pull_request:
12-
# You can name your branch dev-foo to get CI runs.
13-
branches: [main, 'dev-**']
12+
branches-ignore: ['llvm-**']
1413
merge_group:
1514
branches: [main, 'dev-**']
1615
types: [checks_requested]

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ python/triton/language/extra
3131
# Proton
3232
python/triton/profiler
3333

34+
# Instrumentation
35+
python/triton/instrumentation
36+
3437
# Python caches
3538
__pycache__/
3639
*.py[cod]

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,19 +399,77 @@ def format_of(ty):
399399
return src
400400

401401

402+
def serialize_kernel_metadata(arg, args_dict):
403+
args_dict["num_warps"] = arg.num_warps
404+
args_dict["threads_per_warp"] = arg.threads_per_warp
405+
args_dict["shared_memory"] = arg.shared
406+
args_dict["kernel_name"] = arg.name
407+
args_dict["spv_name"] = f"{arg.name}.spv"
408+
409+
410+
def serialize_args(args, constants, signature):
411+
import numbers
412+
dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
413+
if not os.path.exists(dir_path):
414+
os.makedirs(dir_path)
415+
print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
416+
417+
cnt = 0
418+
args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
419+
args_dict["argument_list"] = []
420+
counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
421+
cnt = 4
422+
for arg in args[cnt:]:
423+
if type(arg).__name__ == "KernelMetadata":
424+
serialize_kernel_metadata(arg, args_dict)
425+
426+
if isinstance(arg, torch.Tensor):
427+
cpu_tensor = arg.cpu()
428+
tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
429+
with open(tensor_path, "wb") as f:
430+
torch.save(cpu_tensor, f)
431+
new_arg = {
432+
"name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
433+
signature[counts["karg_cnt"]]
434+
}
435+
args_dict["argument_list"].append(new_arg)
436+
counts["karg_cnt"] += 1
437+
counts["tensors"] += 1
438+
439+
if isinstance(arg, numbers.Number):
440+
if counts["karg_cnt"] not in constants:
441+
new_arg = {
442+
"name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
443+
signature[counts["karg_cnt"]]
444+
}
445+
args_dict["argument_list"].append(new_arg)
446+
counts["karg_cnt"] += 1
447+
counts["scalars"] += 1
448+
cnt += 1
449+
# Dump argument info as a JSON file
450+
json_path = os.path.join(dir_path, "args_data.json")
451+
with open(json_path, "w", encoding="utf-8") as json_file:
452+
import json
453+
json.dump(args_dict, json_file, indent=4)
454+
455+
402456
class XPULauncher:
403457

404458
def __init__(self, src, metadata): # pylint: disable=unused-argument
405459
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
406460
constants = src.constants if hasattr(src, "constants") else {}
407461
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
408-
constants = {cst_key(key): value for key, value in constants.items()}
409-
signature = {cst_key(key): value for key, value in src.signature.items()}
410-
src = make_launcher(constants, signature, ids)
462+
self.constants = {cst_key(key): value for key, value in constants.items()}
463+
self.signature = {cst_key(key): value for key, value in src.signature.items()}
464+
src = make_launcher(self.constants, self.signature, ids)
411465
mod = compile_module_from_src(src, "__triton_launcher")
412466
self.launch = mod.launch
413467

414468
def __call__(self, *args, **kwargs):
469+
# Serialize KernelArguments for SPIR-V Runner
470+
serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
471+
if serialize_kernel_args:
472+
serialize_args(args, self.constants, self.signature)
415473
self.launch(*args, **kwargs)
416474

417475

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8888
mlir::registerTritonAMDGPUStreamPipeline();
8989
mlir::registerTritonAMDGPUStreamPipelineV2();
9090
mlir::registerTritonAMDGPUCanonicalizePointers();
91-
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
92-
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
9391

9492
// TODO: register Triton & TritonGPU passes
9593
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
8888
// encoding not available
8989
return resultVals;
9090
Attribute baseEncoding = encoding;
91-
if (isa<AMDMfmaEncodingAttr>(baseEncoding))
92-
// TODO: this logic seems incorrect for mfma layout. Skip for now.
93-
// We saw mismatches for some flash-attention tests on AMD backend.
94-
// Note that this logic works for sliced layout whose parent is
91+
if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
92+
isa<AMDWmmaEncodingAttr>(baseEncoding))
93+
// TODO: this logic seems incorrect for mfma and wmma layout. Skip for
94+
// now. We saw mismatches for some flash-attention and dot tests on AMD
95+
// backend. Note that this logic works for sliced layout whose parent is
9596
// mfma layout. Therefore, this is not combined with the following check.
9697
return resultVals;
9798
while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,15 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
2727
constexpr int patternBenefitClampOptimizedPattern = 20;
2828
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
2929

30-
struct BackendCallbacks {
31-
/**
32-
* A backend-specific callback for appending auxiliary data during
33-
* `LocalStoreOp` conversion.
34-
*
35-
* @param[in] op The reference to the re-written `LocalStoreOp`.
36-
* @param[in] count The number of issued LLVM instructions.
37-
* @param[in] type The input type of issued LLVM instructions.
38-
*/
39-
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
40-
Type llvmOpType)>
41-
localStoreOpConversion = nullptr;
42-
};
43-
4430
void populateElementwiseOpToLLVMPatterns(
4531
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
4632
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
4733
PatternBenefit benefit);
4834

49-
// The given callback is invoked at the end of a successful rewrite. The
50-
// callback receives 1) the current source op, 2) the number of issued LLVM
51-
// instructions and 3) their input types. Each MLIR backend can provide a
52-
// callback and, thus, handle backend-specific behaviors.
53-
void populateMemoryOpToLLVMPattern(
54-
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
55-
RewritePatternSet &patterns, PatternBenefit benefit,
56-
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
35+
void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
36+
const TargetInfoBase &targetInfo,
37+
RewritePatternSet &patterns,
38+
PatternBenefit benefit);
5739

5840
void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
5941
RewritePatternSet &patterns,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
13661366
Location loc, RewriterBase &rewriter,
13671367
const TargetInfoBase &target);
13681368

1369-
void storeDistributedToShared(
1370-
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
1371-
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
1372-
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
1373-
std::pair<size_t, Type> *const llvmOpCount = nullptr);
1369+
void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
1370+
Type elemLlvmTy, ArrayRef<Value> srcVals,
1371+
Value smemBase, ArrayRef<Value> dstStrides,
1372+
Location loc, RewriterBase &rewriter,
1373+
const TargetInfoBase &target);
13741374

13751375
inline Value getStructFromSharedMemoryObject(Location loc,
13761376
const SharedMemoryObject &smemObj,

0 commit comments

Comments (0)