Skip to content

Commit 26e4cd5

Browse files
committed
Merge branch 'lesh/conda-oct' of https://github.com/intel/intel-xpu-backend-for-triton into lesh/conda-oct
2 parents 2a5ca75 + 5b73c7c commit 26e4cd5

File tree

11 files changed

+462
-147
lines changed

11 files changed

+462
-147
lines changed

test/Analysis/intel/test-alignment.mlir renamed to test/Analysis/intel/test-axis-info.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,18 @@ module {
876876
tt.return %int_min : i64
877877
}
878878
}
879+
880+
// -----
881+
882+
// CHECK-LABEL: @make_tensor_ptr
883+
tt.func public @make_tensor_ptr(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 32 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) {
884+
%c0_i32 = arith.constant 0 : i32
885+
%c1_i64 = arith.constant 1 : i64
886+
%c32_i64 = arith.constant 32 : i64
887+
%c128_i64 = arith.constant 128 : i64
888+
// CHECK: %0 = tt.make_tensor_ptr %arg0, {{.*}} => contiguity = [128, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
889+
%0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
890+
// CHECK: %1 = tt.make_tensor_ptr %arg1, {{.*}} => contiguity = [32, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
891+
%1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c32_i64], [%c1_i64, %arg2], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x16xf8E5M2>>
892+
tt.return
893+
}

test/lib/Analysis/intel/TestAxisInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct TestAxisInfoPass
1313

1414
StringRef getArgument() const final { return "test-print-axis-info"; }
1515
StringRef getDescription() const final {
16-
return "print the result of the alignment analysis pass";
16+
return "print the result of the axis analysis pass";
1717
}
1818

1919
void runOnOperation() override {

third_party/intel/backend/driver.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,19 +435,78 @@ def format_of(ty):
435435
return src
436436

437437

438+
def serialize_kernel_metadata(arg, args_dict):
439+
args_dict['num_warps'] = arg.num_warps
440+
args_dict['threads_per_warp'] = arg.threads_per_warp
441+
args_dict['shared_memory'] = arg.shared
442+
args_dict['kernel_name'] = arg.name
443+
args_dict['spv_name'] = f"{arg.name}.spv"
444+
445+
446+
def serialize_args(args, constants, signature):
447+
import torch
448+
import numbers
449+
dir_path = os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS')
450+
if not os.path.exists(dir_path):
451+
os.makedirs(dir_path)
452+
print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
453+
454+
cnt = 0
455+
args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
456+
args_dict['argument_list'] = []
457+
counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
458+
cnt = 4
459+
for arg in args[cnt:]:
460+
if type(arg).__name__ == "KernelMetadata":
461+
serialize_kernel_metadata(arg, args_dict)
462+
463+
if isinstance(arg, torch.Tensor):
464+
cpu_tensor = arg.cpu()
465+
tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
466+
with open(tensor_path, 'wb') as f:
467+
torch.save(cpu_tensor, f)
468+
new_arg = {
469+
"name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
470+
signature[counts['karg_cnt']]
471+
}
472+
args_dict['argument_list'].append(new_arg)
473+
counts['karg_cnt'] += 1
474+
counts['tensors'] += 1
475+
476+
if isinstance(arg, numbers.Number):
477+
if counts['karg_cnt'] not in constants:
478+
new_arg = {
479+
"name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
480+
signature[counts['karg_cnt']]
481+
}
482+
args_dict['argument_list'].append(new_arg)
483+
counts['karg_cnt'] += 1
484+
counts['scalars'] += 1
485+
cnt += 1
486+
# Dump argument info as a JSON file
487+
json_path = os.path.join(dir_path, 'args_data.json')
488+
with open(json_path, 'w') as json_file:
489+
import json
490+
json.dump(args_dict, json_file, indent=4)
491+
492+
438493
class XPULauncher(object):
439494

440495
def __init__(self, src, metadata):
441496
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
442497
constants = src.constants if hasattr(src, "constants") else dict()
443498
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
444-
constants = {cst_key(key): value for key, value in constants.items()}
445-
signature = {cst_key(key): value for key, value in src.signature.items()}
446-
src = make_launcher(constants, signature, ids)
499+
self.constants = {cst_key(key): value for key, value in constants.items()}
500+
self.signature = {cst_key(key): value for key, value in src.signature.items()}
501+
src = make_launcher(self.constants, self.signature, ids)
447502
mod = compile_module_from_src(src, "__triton_launcher")
448503
self.launch = mod.launch
449504

450505
def __call__(self, *args, **kwargs):
506+
# Serialize KernelArguments for SPIR-V Runner
507+
serialize_kernel_args = os.getenv('TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS', None)
508+
if serialize_kernel_args:
509+
serialize_args(args, self.constants, self.signature)
451510
self.launch(*args, **kwargs)
452511

453512

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5050
return lhs * rhs;
5151
}
5252

53+
RankedTensorType getRankedTensorType(Type ptrTy) {
54+
return isTensorPointerType(ptrTy)
55+
? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType())
56+
: dyn_cast<RankedTensorType>(ptrTy);
57+
}
58+
5359
class AxisInfoVisitor {
5460
public:
5561
AxisInfoVisitor() = default;
@@ -409,7 +415,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
409415

410416
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
411417
int dim) override {
412-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
418+
auto resTy = getRankedTensorType(op.getType());
413419
if (!resTy)
414420
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
415421
auto shape = resTy.getShape();
@@ -464,7 +470,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
464470
private:
465471
int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
466472
int dim) override {
467-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
473+
auto resTy = getRankedTensorType(op.getType());
468474
if (!resTy)
469475
return BinaryOpVisitorImpl<OpTy>::getContiguity(op, lhs, rhs, dim);
470476
auto shape = resTy.getShape();
@@ -498,7 +504,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
498504

499505
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
500506
int dim) override {
501-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
507+
auto resTy = getRankedTensorType(op.getType());
502508
if (!resTy)
503509
return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
504510
auto shape = resTy.getShape();
@@ -647,7 +653,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
647653
AxisInfo
648654
getAxisInfo(OpTy op,
649655
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
650-
auto resTy = dyn_cast<RankedTensorType>(op.getType());
656+
auto resTy = getRankedTensorType(op.getType());
651657
if (!resTy)
652658
return AxisInfo();
653659
auto shape = resTy.getShape();
@@ -995,6 +1001,55 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
9951001
}
9961002
};
9971003

1004+
class MakeTensorPtrOpAxisInfoVisitor final
1005+
: public AxisInfoVisitorImpl<triton::MakeTensorPtrOp> {
1006+
public:
1007+
using AxisInfoVisitorImpl<triton::MakeTensorPtrOp>::AxisInfoVisitorImpl;
1008+
1009+
AxisInfo
1010+
getAxisInfo(triton::MakeTensorPtrOp op,
1011+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
1012+
LDBG("MakeTensorPtrOpAxisInfoVisitor: " << *op);
1013+
assert(op.getShape().size() == 2 && operands.size() == 7 &&
1014+
"MakeTensorPtrOp should have 2D shape");
1015+
1016+
AxisInfo ptrInfo = operands[0]->getValue();
1017+
AxisInfo shapeInfo0 = operands[1]->getValue();
1018+
AxisInfo shapeInfo1 = operands[2]->getValue();
1019+
AxisInfo strideInfo0 = operands[3]->getValue();
1020+
AxisInfo strideInfo1 = operands[4]->getValue();
1021+
1022+
std::optional<int64_t> shape0 = shapeInfo0.getConstantValue();
1023+
std::optional<int64_t> shape1 = shapeInfo1.getConstantValue();
1024+
std::optional<int64_t> stride0 = strideInfo0.getConstantValue();
1025+
std::optional<int64_t> stride1 = strideInfo1.getConstantValue();
1026+
1027+
AxisInfo::DimVectorT contiguity{
1028+
shape0.has_value() && (stride0 == 1) ? shape0.value() : 1,
1029+
shape1.has_value() && (stride1 == 1) ? shape1.value() : 1};
1030+
1031+
int64_t ptrDivisibility = ptrInfo.getDivisibility()[0];
1032+
int64_t strideDivisibility0 = strideInfo0.getDivisibility()[0];
1033+
int64_t strideDivisibility1 = strideInfo1.getDivisibility()[0];
1034+
1035+
LDBG("ptrDivisibility: " << ptrDivisibility);
1036+
LDBG("strideDivisibility0: " << strideDivisibility0);
1037+
LDBG("strideDivisibility1: " << strideDivisibility1);
1038+
1039+
AxisInfo::DimVectorT divisibility{1, 1};
1040+
if (ptrDivisibility > 1) {
1041+
if (contiguity[0] > 1)
1042+
divisibility[0] = std::min(ptrDivisibility, strideDivisibility1);
1043+
if (contiguity[1] > 1)
1044+
divisibility[1] = std::min(ptrDivisibility, strideDivisibility0);
1045+
}
1046+
1047+
AxisInfo::DimVectorT constancy{1, 1};
1048+
1049+
return AxisInfo(contiguity, divisibility, constancy);
1050+
}
1051+
};
1052+
9981053
//===----------------------------------------------------------------------===//
9991054
// AxisInfoAnalysis
10001055
//===----------------------------------------------------------------------===//
@@ -1042,11 +1097,13 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10421097
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10431098
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10441099
visitors.append<LoadOpAxisInfoVisitor>();
1100+
visitors.append<MakeTensorPtrOpAxisInfoVisitor>();
10451101
}
10461102

10471103
LogicalResult AxisInfoAnalysis::visitOperation(
10481104
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10491105
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1106+
LDBG("visitOperation: << " << *op);
10501107
// TODO: For sure not the right way to do this
10511108
// but why is scf.if not initialized otherwise?
10521109
for (auto op : operands)
@@ -1204,7 +1261,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12041261
}
12051262

12061263
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
1207-
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
1264+
auto tensorTy = getRankedTensorType(ptr.getType());
12081265
if (!tensorTy)
12091266
return 1;
12101267
auto layout = tensorTy.getEncoding();
@@ -1226,7 +1283,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12261283
}
12271284

12281285
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
1229-
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
1286+
auto tensorTy = getRankedTensorType(ptr.getType());
12301287
if (!tensorTy)
12311288
return 1;
12321289
auto *axisInfo = getAxisInfo(ptr);
@@ -1254,7 +1311,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12541311
}
12551312

12561313
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
1257-
auto tensorTy = dyn_cast<RankedTensorType>(mask.getType());
1314+
auto tensorTy = getRankedTensorType(mask.getType());
12581315
if (!tensorTy)
12591316
return 1;
12601317
auto *axisInfo = getAxisInfo(mask);

utils/SPIRVRunner/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
22
project(reproducer)
3-
43
set(CMAKE_CXX_COMPILER icpx)
54
set(BUILD_SHARED_LIBS OFF)
65

76
list(APPEND CMAKE_PREFIX_PATH "/opt/intel/oneapi/tbb/latest/lib/cmake/tbb/")
87

98
find_package(Torch REQUIRED)
109

10+
include(ExternalProject)
11+
ExternalProject_Add(
12+
json
13+
GIT_REPOSITORY https://github.com/nlohmann/json.git
14+
GIT_TAG v3.11.2
15+
PREFIX ${CMAKE_BINARY_DIR}/nlohmann_json
16+
CONFIGURE_COMMAND ""
17+
BUILD_COMMAND ""
18+
INSTALL_COMMAND ""
19+
)
20+
set(JSON_INCLUDE_DIR ${CMAKE_BINARY_DIR}/nlohmann_json/src/json/include/)
21+
1122
# Add preview-breaking-changes for ABI compatibility with SYCL library linked by PyTorch: https://github.com/pytorch/pytorch/commit/92bebb46fa9fd60523d8aeb7b5f1a3f488c4cd93
1223
set(COMPILE_FLAGS "-fsycl -Wall -fpreview-breaking-changes")
1324
set(LINK_FLAGS "-fsycl -lze_loader")
@@ -16,9 +27,10 @@ set(SYCL_FUNCTIONS_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/in
1627

1728
set(TARGET_NAME SPIRVRunner)
1829
add_executable(${TARGET_NAME} ${TARGET_NAME}.cpp)
19-
target_include_directories(${TARGET_NAME} PRIVATE "/opt/intel/oneapi/compiler/latest/include" ${SYCL_FUNCTIONS_INCLUDE_DIR})
30+
target_include_directories(${TARGET_NAME} PRIVATE "/opt/intel/oneapi/compiler/latest/include" ${SYCL_FUNCTIONS_INCLUDE_DIR} ${JSON_INCLUDE_DIR})
2031
set_target_properties(${TARGET_NAME} PROPERTIES COMPILE_FLAGS "${COMPILE_FLAGS}")
2132
set_target_properties(${TARGET_NAME} PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
33+
add_dependencies(${TARGET_NAME} json)
2234

2335
target_link_libraries(${TARGET_NAME} "${TORCH_LIBRARIES}")
2436
set_property(TARGET ${TARGET_NAME} PROPERTY CXX_STANDARD 17)

utils/SPIRVRunner/README.md

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,44 @@ CMAKE_PREFIX_PATH=/abs/path/to/TorchConfig.cmake/FromAbove/ cmake -DCMAKE_BUILD_
1717
make -j
1818
```
1919

20-
## Configuring
20+
## Configuration
2121

22-
`SPIRVRunner` is configured to run the `add_kernel.spv` SPIRV binary with inputs `x.py` and `y.py`. `add_kernel.spv` was generated from the `01-vector-add.py` tutorial.
22+
### Generate Data
2323

24-
Kernels of different shapes require modifying parameters manually in the `SPIRVRunner`. Two places require modification:
24+
In order to utilize this utility, Triton application must be run with following environment variables enabled
25+
Provide the path to the directory where the serialized JSON, tensors and SPRI-V binary stored. It is recommended to clear triton cache.
26+
27+
```
28+
export TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS=< Absolute path to SPV Dumps >
29+
```
30+
31+
Following input data is generated,
32+
33+
1. args_data.json - (Kernel Arguments / Grid Configuration)
34+
2. tensors (Tensors used by the kernel (.pt))
35+
3. SPIR-V binary (.spv)
2536

26-
1. `launchKernel`: Add input Tensors to the function signature, add arguments as variables within the function. Arguments can be pulled from the `args` variable to `XPULauncher.__call__` method in `driver.py`. Arguments should be passed to the `sycl_kernel_launch` function. Note that we currently rely on `sycl::memcpy` to move the PyTorch Tensor to XPU. In later versions of PyTorch we should be able to delegate this responsibility to `PyTorch`, and pass the raw XPU `data_ptr()` from `PyTorch` to the kernel.
27-
2. `sycl_kernel_launch`: Place all `arg*` parameters into the `params` array and add an appropriate call to `set_scalar_arg` for each param, which tells `SYCL` what the arguments are for the kernel we are going to launch.
2837

2938
## Running
3039

31-
Once the `SPIRVRunner` has been appropriately configured for the kernel and inputs, run the binary with no arguments:
40+
Help:
41+
`./build/SPIRVRunner` < Output Tensor Name >
42+
43+
Note: `Output Tensor Name` is essentially a chosen tensor that needs to be copied back to the CPU and written to disk. Additionally, the name must match the tensor's name (tensor_) and number as specified in the JSON file. Please refer args_data.json file.
44+
45+
### Demo (01-vector-add.py)
46+
47+
`SPIRVRunner` is configured to run the `add_kernel.spv` SPIRV binary with inputs `tensor_0.pt` and `tensor_1.pt` and output `tensor_2.pt`. `add_kernel.spv` was generated from the `01-vector-add.py` tutorial.
3248

33-
`./build/SPIRVRunner`
49+
SPIRVRunner Usage:
50+
`./build/SPIRVRunner tensor_2`
3451

3552
Expected output follows:
3653

3754
```
3855
Running on device: Intel(R) Data Center GPU Max 1100
39-
Tensor a: [98432], Float (393728 bytes)
40-
Tensor b: [98432], Float (393728 bytes)
4156
Read 3772 byte kernel.
57+
create kernel:add_kernel
4258
Loaded kernel with 0 registers and 0 register spills.
4359
Tensor output: [98432], Float (393728 bytes)
4460
Kernel return output: 1.37129

0 commit comments

Comments
 (0)