
Commit e488dea

Merge branch 'main' of https://github.com/intel/intel-xpu-backend-for-triton into amyachev/device

2 parents: 6c4df59 + 6588f0d

File tree: 52 files changed, +1167 −1140 lines


.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ runs:
        cd pytorch
        pip install wheel
        pip install -r requirements.txt
-       USE_STATIC_MKL=1 python setup.py bdist_wheel
+       USE_STATIC_MKL=1 CFLAGS="-Wno-error=maybe-uninitialized" python setup.py bdist_wheel

   - name: Install PyTorch (built from source)
     if: ${{ inputs.mode == 'source' }}

.github/workflows/integration-tests.yml

Lines changed: 2 additions & 2 deletions

@@ -236,7 +236,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit
       - name: Install Triton
         env:
           CUDA_HOME: "/usr/local/cuda"
@@ -569,7 +569,7 @@ jobs:
           python3 -m venv ~/.venv
           source ~/.venv/bin/activate
           python3 -m pip install --upgrade pip
-          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit pybind11
       - name: Install Triton
         env:
           TRITON_BUILD_WITH_O1: "true"

.github/workflows/integration-tests.yml.in

Lines changed: 2 additions & 2 deletions

@@ -268,7 +268,7 @@ jobs:
      - name: Install pip dependencies
        run: |
          python3 -m pip install --upgrade pip
-         python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
+         python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit

      - name: Install Triton
        env:
@@ -481,7 +481,7 @@ jobs:
          python3 -m venv ~/.venv
          source ~/.venv/bin/activate
          python3 -m pip install --upgrade pip
-         python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11
+         python3 -m pip install cython setuptools wheel cmake==3.24 ninja lit pybind11
      - name: Install Triton
        env:
          TRITON_BUILD_WITH_O1: "true"

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 28 deletions

@@ -158,21 +158,6 @@ jobs:
           python ../../scripts/build_report.py $REPORTS/matmul-performance-base.csv $REPORTS/gemm-triton-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
           python ../../scripts/build_report.py $REPORTS/matmul-performance-base.csv $REPORTS/gemm-xetla-report.csv --benchmark gemm --compiler xetla --param_cols "B,M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG

-      - name: Run Triton GEMM kernel benchmark - default path
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_default') }}
-        run: |
-          cd benchmarks/triton_kernels_benchmark
-          # Default path:
-          TRITON_INTEL_ADVANCED_PATH=0 \
-          IGC_VISAOptions=" -enableBCR -nolocalra" \
-          IGC_DisableLoopUnroll=1 \
-          python gemm_benchmark.py --reports $REPORTS
-          mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-default-path.csv
-
-          source ../../scripts/capture-hw-details.sh
-          TAG="${TAG}-dflt"
-          python ../../scripts/build_report.py $REPORTS/matmul-performance-default-path.csv $REPORTS/gemm-triton-default-report.csv --benchmark gemm --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-
       - name: Run Triton GEMM kernel benchmark - advanced path
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_advanced') }}
         run: |
@@ -260,19 +245,6 @@ jobs:
           python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
           python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG

-      - name: Run Triton FA kernel benchmark - default path
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmark || '[]'), 'flash_attention_fwd_benchmark.py_default') }}
-        run: |
-          cd benchmarks/triton_kernels_benchmark
-          TRITON_INTEL_ADVANCED_PATH=0 \
-          TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 \
-          IGC_VISAOptions=" -enableBCR" \
-          python flash_attention_fwd_benchmark.py --reports $REPORTS
-
-          TAG="${TAG}-dflt"
-          source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-default-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-
       - name: Run Triton FA kernel benchmark - advanced path
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py_advanced') }}
         run: |

CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
   include_directories(${PYTHON_SRC_PATH})

   # Python Interpreter is used to run lit tests
-  find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
+  find_package(Python3 REQUIRED COMPONENTS Development.Module Interpreter)
   find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}")

   if (DEFINED TRITON_PLUGIN_DIRS)

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 20 additions & 19 deletions

@@ -227,28 +227,28 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'M', 'K', 'N'],
+        x_names=['B', 'M', 'N', 'K'],
         # different possible values for `x_name`
         x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] +  #
         [  #
-            [1, 1, 5120, 13824],  #
-            [1, 4, 4096, 12288],  #
+            [1, 1, 13824, 5120],  #
+            [1, 4, 12288, 4096],  #
             [1, 512, 8192, 8192],  #
             [1, 512, 8192, 32768],  #
             [1, 512, 32768, 8192],  #
-            [1, 1024, 16384, 8192],  #
-            [1, 1024, 28672, 8192],  #
-            [1, 3072, 4096, 3072],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
-            [1, 4096, 16384, 8192],  #
-            [1, 8192, 16384, 1024],  #
-            [1, 8192, 16384, 4096],  #
+            [1, 1024, 8192, 16384],  #
+            [1, 1024, 8192, 28672],  #
+            [1, 3072, 3072, 4096],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
+            [1, 4096, 8192, 16384],  #
+            [1, 8192, 1024, 16384],  #
+            [1, 8192, 4096, 16384],  #
             [1, 16384, 1024, 8192],  #
             [1, 16384, 4096, 8192],  #
             [1, 16384, 8192, 1024],  #
             [1, 16384, 8192, 4096],  #
             [4, 32768, 128, 4096],  #
             [4, 32768, 4096, 128],  #
-            [32, 4096, 4096, 128],  #
+            [32, 4096, 128, 4096],  #
             [4096, 8, 128, 16384],  #
             [4096, 8, 16384, 128]
         ],
@@ -268,6 +268,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
 def benchmark(B, M, N, K, provider):
     a_shape, b_shape = get_shapes(B, M, N, K, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)

+    torch.manual_seed(0)
     a = torch.rand(a_shape, device='xpu', dtype=torch.bfloat16)
     b = torch.rand(b_shape, device='xpu', dtype=torch.bfloat16)

@@ -291,10 +292,10 @@ def benchmark(B, M, N, K, provider):
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
-            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+            c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
         else:
             assert len(a.shape) == 2, 'Expecting shape of length 2'
-            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+            c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
@@ -304,17 +305,17 @@ def benchmark(B, M, N, K, provider):
                                          kernel_name='matmul_kernel_with_block_pointers')
     elif provider == 'xetla':
         if B == 1:
-            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-            cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+            c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+            acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+            cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
         else:
-            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
-            cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
+            c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
+            acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
+            cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
         name = f'gemm_shape_{B}_{M}_{K}_{N}'
         # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
         # better performance.
-        if (B, M, N, K) == (1, 3072, 4096, 3072):
+        if (B, M, N, K) == (1, 3072, 3072, 4096):
             name = 'gemm_streamk_shape_3072_4096_3072'
         func = getattr(xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
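Two of these changes are worth spelling out. First, x_names now reads ['B', 'M', 'N', 'K'], so each row of x_vals is consumed positionally as (B, M, N, K); the shape list is rewritten only to keep the same problem sizes under the new ordering, which also matches the benchmark(B, M, N, K, provider) signature. Second, seeding the RNG and switching the output buffers from torch.empty to torch.zeros makes the accuracy comparison deterministic, since results no longer depend on whatever memory the buffers happened to contain. A minimal sketch of that pattern, using hypothetical small CPU shapes instead of the benchmark's xpu tensors:

import torch

torch.manual_seed(0)  # reproducible inputs on every run

# Hypothetical shapes for illustration; the real benchmark takes them from x_vals.
M, N, K = 128, 256, 64
a = torch.rand((M, K), dtype=torch.bfloat16)
b = torch.rand((K, N), dtype=torch.bfloat16)

# Zero-initialized output: a kernel that accumulates into `c` or skips tiles
# starts from a known value rather than from uninitialized memory.
c = torch.zeros((M, N), dtype=torch.float32)
c += torch.matmul(a.float(), b.float())  # stand-in for the Triton/XeTLA kernel

ref = torch.matmul(a.float(), b.float())
torch.testing.assert_close(c, ref, rtol=1e-2, atol=1e-3)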

bin/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
   TritonTransforms
   TritonGPUTransforms
   TritonNvidiaGPUTransforms
+  TritonIntelLLVMIR
   MLIRGPUToROCDLTransforms
   ${dialect_libs}
   ${conversion_libs}
@@ -88,6 +89,7 @@ target_link_libraries(triton-llvm-opt PRIVATE
   LLVMSupport
   LLVMOption
   LLVMCodeGen
+  TritonIntelLLVMIR
   TritonIntelGPUIR
   )
 export_executable_symbols_for_plugins(triton-llvm-opt)

bin/triton-llvm-opt.cpp

Lines changed: 8 additions & 0 deletions

@@ -1,6 +1,7 @@
 /// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
 /// passes.
 #include "lib/Target/LLVMIR/LLVMPasses.h"
+#include "third_party/intel/lib/LLVMIR/LLVMPasses.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -42,6 +43,11 @@
                    llvm::cl::desc("run pass to break phi struct"),
                    cl::init(false));

+static cl::opt<bool> FreezeMaskedDivRem(
+    "freeze-masked-div-rem",
+    llvm::cl::desc("run pass to insert freeze between masked load and div/rem"),
+    cl::init(false));
+
 namespace {
 static std::function<Error(Module *)> makeOptimizingPipeline() {
   return [](Module *m) -> Error {
@@ -62,6 +68,8 @@ static std::function<Error(Module *)> makeOptimizingPipeline() {
     llvm::FunctionPassManager fpm;
     if (BreakStructPhiNodes)
       fpm.addPass(BreakStructPhiNodesPass());
+    if (FreezeMaskedDivRem)
+      fpm.addPass(FreezeMaskedDivRemPass());
     mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
     mpm.run(*m, mam);
     return Error::success();

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
 // Return the minClusterId and maxClusterId for the given ForOp.
 std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
 std::pair<int, int> getStageCluster(Operation *op);
-void setStageCluster(scf::ForOp &forOp, Operation *op, int stage, int cluster);
+void setStageCluster(Operation *op, int stage, int cluster);
 } // namespace triton
 } // namespace mlir

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 3 additions & 8 deletions

@@ -64,11 +64,7 @@ class OpBuilderWithStage : public OpBuilder {
   OpTy createWithStage(Location location, int stage, int cluster,
                        Args &&...args) {
     OpTy op = OpBuilder::create<OpTy>(location, std::forward<Args>(args)...);
-    auto ctx = getContext();
-    op->setAttr(mlir::triton::kLoopStageAttrName,
-                IntegerAttr::get(IntegerType::get(ctx, 32), stage));
-    op->setAttr(mlir::triton::kLoopClusterAttrName,
-                IntegerAttr::get(IntegerType::get(ctx, 32), cluster));
+    tt::setStageCluster(op, stage, cluster);
     return op;
   }
   using OpBuilder::create;
@@ -204,9 +200,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
     // Prefetch load if is not MMAV3 and is used by the dot.
     if (loadToInfo[loadOp].usedByDot) {
       assert(stageForFirstUse >= 1);
-      tt::setStageCluster(forOp, wait, stageForFirstUse - 1, maxClusterId + 1);
-      tt::setStageCluster(forOp, viewLoad, stageForFirstUse - 1,
-                          maxClusterId + 1);
+      tt::setStageCluster(wait, stageForFirstUse - 1, maxClusterId + 1);
+      tt::setStageCluster(viewLoad, stageForFirstUse - 1, maxClusterId + 1);
       retCode = stageForFirstUse - 1;
     }
   }
