Skip to content

Commit 04d3349

Browse files
authored
Merge branch 'main' into amyachev/issue2736
2 parents 64cbed5 + cc1d4c5 commit 04d3349

File tree

11 files changed

+93
-11
lines changed

11 files changed

+93
-11
lines changed

.github/workflows/build-test-gpu.yml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,10 @@ on:
3939
description: Ignore pytest.skip
4040
type: boolean
4141
default: false
42+
use_system_python:
43+
description: Use system Python
44+
type: boolean
45+
default: false
4246

4347
permissions: read-all
4448

@@ -60,3 +64,4 @@ jobs:
6064
skip_list: ${{ inputs.skip_list }}
6165
run_name: ${{ inputs.run_name || format('Build and test {0}', inputs.runner_label) }}
6266
enable_unskip: ${{ inputs.enable_unskip }}
67+
use_system_python: ${{ inputs.use_system_python || false }}

.github/workflows/build-test-reusable.yml

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,10 @@ on:
5656
description: Runner label for version
5757
type: string
5858
default: runner-0.0.20
59+
use_system_python:
60+
description: Use system Python
61+
type: boolean
62+
default: false
5963

6064
permissions: read-all
6165

@@ -91,10 +95,16 @@ jobs:
9195
key: pip-${{ inputs.python_version }}-${{ hashFiles('python/pyproject.toml', 'python/setup.py') }}-${{ env.CACHE_NUMBER }}
9296

9397
- name: Install Python ${{ inputs.python_version }}
98+
if: ${{ !inputs.use_system_python }}
9499
uses: actions/setup-python@v5
95100
with:
96101
python-version: ${{ inputs.python_version }}
97102

103+
- name: Identify Python version
104+
run: |
105+
PYTHON_VERSION="$(python -c 'import sys; print(f"{sys.version_info[0]}.{ sys.version_info[1]}")')"
106+
echo "PYTHON_VERSION=$PYTHON_VERSION" | tee -a $GITHUB_ENV
107+
98108
- name: Setup PyTorch
99109
uses: ./.github/actions/setup-pytorch
100110
with:

.github/workflows/triton-benchmarks.yml

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,10 @@ on:
2828
description: JSON list of benchmarks to skip
2929
type: string
3030
default: "[]"
31+
use_system_python:
32+
description: Use system Python
33+
type: boolean
34+
default: false
3135
schedule:
3236
- cron: "5 23 * * *"
3337
pull_request:
@@ -67,10 +71,16 @@ jobs:
6771
key: pip-$PYTHON_VERSION-$GITHUB_SHA
6872

6973
- name: Install Python
74+
if: ${{ !(inputs.use_system_python || false) }}
7075
uses: actions/setup-python@v5
7176
with:
7277
python-version: ${{ env.PYTHON_VERSION }}
7378

79+
- name: Identify Python version
80+
run: |
81+
PYTHON_VERSION="$(python -c 'import sys; print(f"{sys.version_info[0]}.{ sys.version_info[1]}")')"
82+
echo "PYTHON_VERSION=$PYTHON_VERSION" | tee -a $GITHUB_ENV
83+
7484
- name: Install Python build dependencies
7585
run: |
7686
pip install wheel cmake

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
128128
[512, 32768, 8192],
129129
[1024, 28672, 8192],
130130
[3072, 4096, 3072],
131+
[4096, 4096, 4096],
131132
],
132133
line_arg='provider',
133134
# argument name whose value corresponds to a different line in the plot
@@ -152,17 +153,17 @@ def benchmark(M, N, K, provider):
152153
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
153154
quantiles=quantiles)
154155
elif provider == 'triton':
155-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
156+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
156157
triton_fn = lambda: matmul(a, b, c)
157158
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
158159
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
159160
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
160161
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
161162
quantiles=quantiles, kernel_name='_kernel')
162163
elif provider == 'xetla':
163-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
164-
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
165-
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
164+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
165+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
166+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
166167

167168
name = f'gemm_splitk_shape_{M}_{K}_{N}'
168169
func = getattr(xetla_kernel, name)

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -275,17 +275,17 @@ def benchmark(M, N, K, provider):
275275
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
276276
quantiles=quantiles)
277277
elif provider == 'triton':
278-
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
278+
c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
279279
triton_fn = lambda: matmul(a, b, c)
280280
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
281281
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
282282
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
283283
quantiles=quantiles,
284284
kernel_name=['first_wave', 'full_tiles'])
285285
elif provider == 'xetla':
286-
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
287-
acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
288-
cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
286+
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
287+
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
288+
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
289289

290290
name = f'gemm_streamk_shape_{M}_{K}_{N}'
291291
func = getattr(xetla_kernel, name)

benchmarks/xetla_kernel/python_main.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -317,6 +317,9 @@ PYBIND11_MODULE(xetla_kernel, m) {
317317
m.def("gemm_splitk_shape_3072_4096_3072",
318318
&bf16_split_k_gemm<3072, 4096, 3072, kslicing_impl_t::global>,
319319
"bf16_gemm_splitk (XeTLA)");
320+
m.def("gemm_splitk_shape_4096_4096_4096",
321+
&bf16_split_k_gemm<4096, 4096, 4096, kslicing_impl_t::global>,
322+
"bf16_gemm_splitk (XeTLA)");
320323
// flash_attn
321324
m.def("flash_attn_causal_false", &flash_attn<false, false, false>,
322325
"flash attn fwd (XeTLA)");

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73
1+
bd9145c8c21334e099d51b3e66f49d51d24931ee

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -207,7 +207,7 @@ def get_llvm_package_info():
207207
with open(llvm_hash_path, "r") as llvm_hash_file:
208208
rev = llvm_hash_file.read(8)
209209
name = f"llvm-{rev}-{system_suffix}"
210-
url = f"https://github.com/intel/intel-xpu-backend-for-triton/releases/download/llvm-{rev}/{name}.tar.gz"
210+
url = f"https://oaitriton.blob.core.windows.net/public/llvm-builds/{name}.tar.gz"
211211
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
212212

213213

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1571,7 +1571,7 @@ struct AtomicRMWOpConversion
15711571
auto lowPtrBits = and_(intPtr, i64_val(3));
15721572
auto elemIndex = trunc(i32_ty, lshr(lowPtrBits, i64_val(1)));
15731573
auto alignPtr = inttoptr(rmwPtr.getType(), sub(intPtr, lowPtrBits));
1574-
auto firstValInt = load(i32_ty, alignPtr, 4, false, false, false,
1574+
auto firstValInt = load(i32_ty, alignPtr, 4, false, false, false, false,
15751575
LLVM::AtomicOrdering::acquire);
15761576

15771577
// Create a loop body block. It has a single parameter which holds the

third_party/nvidia/backend/driver.c

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,11 @@
11
#include "cuda.h"
2+
#ifdef WIN32
3+
#define WIN32_LEAN_AND_MEAN
4+
#define NOMINMAX
5+
#include <windows.h>
6+
#else
27
#include <dlfcn.h>
8+
#endif
39
#include <stdbool.h>
410
#define PY_SSIZE_T_CLEAN
511
#include <Python.h>
@@ -161,6 +167,27 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
161167
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
162168
CUtensorMapFloatOOBfill oobFill);
163169

170+
#ifdef WIN32
171+
#define defineGetFunctionHandle(name, symbolName) \
172+
static symbolName##_t name() { \
173+
/* Open the shared library */ \
174+
HMODULE handle = LoadLibraryA("nvcuda.dll"); \
175+
if (!handle) { \
176+
PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll"); \
177+
return NULL; \
178+
} \
179+
symbolName##_t funcHandle = \
180+
(symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName); \
181+
/* Check for errors */ \
182+
long err = GetLastError(); \
183+
if (err) { \
184+
PyErr_SetString(PyExc_RuntimeError, \
185+
"Failed to retrieve " #symbolName " from nvcuda.dll"); \
186+
return NULL; \
187+
} \
188+
return funcHandle; \
189+
}
190+
#else
164191
#define defineGetFunctionHandle(name, symbolName) \
165192
static symbolName##_t name() { \
166193
/* Open the shared library */ \
@@ -182,6 +209,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
182209
} \
183210
return funcHandle; \
184211
}
212+
#endif
185213

186214
defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
187215
cuOccupancyMaxActiveClusters);

0 commit comments

Comments
 (0)