
Commit 052234d

Merge branch 'main' into lesh/conda-oct
2 parents: 5719ed9 + 4c5296d

41 files changed: +1332 / −274 lines


.github/workflows/integration-tests.yml

Lines changed: 8 additions & 8 deletions
@@ -245,9 +245,9 @@ jobs:
          lit -v "${LIT_TEST_DIR}"
      - name: Run python tests on CUDA
        run: |
-         SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
-         if [ ! -d "${SHARED_LIB_DIR}" ]; then
-           echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+         INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
+         if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+           echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
          fi
          cd python/test/unit
          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -257,7 +257,7 @@ jobs:
          TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
          # Run hopper/test_flashattention.py separately to avoid out of gpu memory
          python3 -m pytest -s hopper/test_flashattention.py
-         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
            python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
      - name: Run interpreter tests
        if: ${{ matrix.runner[0] == 'h100-runner-set' }}
@@ -401,9 +401,9 @@ jobs:
          lit -v "${LIT_TEST_DIR}"
      - name: Run python tests on HIP
        run: |
-         SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C"
-         if [ ! -d "${SHARED_LIB_DIR}" ]; then
-           echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+         INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
+         if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+           echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
          fi
          pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
          cd python/test/unit
@@ -412,7 +412,7 @@ jobs:
            --ignore=test_debug.py
          # TODO: uncomment
          # pytest --capture=tee-sys -rfs test_debug.py
-         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
            pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

          # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0

.github/workflows/integration-tests.yml.in

Lines changed: 8 additions & 8 deletions
@@ -279,9 +279,9 @@ jobs:

      - name: Run python tests on CUDA
        run: |
-         SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
-         if [ ! -d "${SHARED_LIB_DIR}" ]; then
-           echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+         INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
+         if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+           echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
          fi
          cd python/test/unit
          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -291,7 +291,7 @@ jobs:
          TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
          # Run hopper/test_flashattention.py separately to avoid out of gpu memory
          python3 -m pytest -s hopper/test_flashattention.py
-         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
            python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

      - name: Run interpreter tests
@@ -397,9 +397,9 @@ jobs:

      - name: Run python tests on HIP
        run: |
-         SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C"
-         if [ ! -d "${SHARED_LIB_DIR}" ]; then
-           echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+         INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
+         if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+           echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
          fi
          pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
          cd python/test/unit
@@ -408,7 +408,7 @@ jobs:
            --ignore=test_debug.py
          # TODO: uncomment
          # pytest --capture=tee-sys -rfs test_debug.py
-         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+         TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
            pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

          # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0

benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
-from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION  # type: ignore # noqa: F401
+from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD  # type: ignore # noqa: F401

-if USE_IPEX_OPTION:
+if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     from triton.runtime import driver
     from . import benchmark_driver
     # replace the launcher with the profilier hook.

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 23 additions & 12 deletions
@@ -10,15 +10,15 @@
 from triton.runtime.build import _build, quiet

 import torch
-import intel_extension_for_pytorch
+
+from .benchmark_testing import USE_IPEX_OPTION

 _dirname = os.getenv("ZE_PATH", default="/usr/local")

 include_dir = [
     os.path.join(_dirname, "include"),
     os.path.join(torch.utils.cmake_prefix_path, "../../include"),
-    os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"),
-    os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include")
+    os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
 ]

 oneapi_root = os.getenv("ONEAPI_ROOT")
@@ -28,12 +28,15 @@
     os.path.join(oneapi_root, "compiler/latest/include/sycl")
 ]

-library_dir = [
-    os.path.join(_dirname, "lib"),
-    os.path.join(torch.utils.cmake_prefix_path, "../../lib"),
-    os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib")
-]
-libraries = ["ze_loader", "sycl", "torch", "intel-ext-pt-gpu"]
+library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
+libraries = ["ze_loader", "sycl", "torch"]
+
+if USE_IPEX_OPTION:
+    import intel_extension_for_pytorch
+
+    include_dir.append(os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include"))
+    library_dir.append(os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib"))
+    libraries.append("intel-ext-pt-gpu")


 def compile_module_from_src(src, name):
@@ -141,6 +144,14 @@ def format_of(ty):
     fmt = "iiiOOOOOO" + args_format
     args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""

+    record_function_header = "#include <ATen/record_function.h>"
+    ipex_header = ""
+    xpu_profiler_record = ""
+    if USE_IPEX_OPTION:
+        record_function_header = "#include <torch/extension.h>"
+        ipex_header = "#include <ipex.h>"
+        xpu_profiler_record = "xpu::profiler_record(kernel_name, event);"
+
     # generate glue code
     src = f"""
 #include <cstddef>
@@ -149,8 +160,8 @@ def format_of(ty):
 #include <iomanip>
 #include <level_zero/ze_api.h>
 #include <sycl/sycl.hpp>
-#include <torch/extension.h>
-#include <ipex.h>
+{record_function_header}
+{ipex_header}

 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
 #include <Python.h>
@@ -291,7 +302,7 @@ def format_of(ty):
     }}
   }};
   auto event = stream.submit(cgf);
-  xpu::profiler_record(kernel_name, event);
+  {xpu_profiler_record}
 }}
 // end sycl
 static PyObject* launch(PyObject* self, PyObject* args) {{
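For context on the two profiling paths: with IPEX the generated SYCL glue code reports kernel events through `xpu::profiler_record`, while the non-IPEX path pulls in `<ATen/record_function.h>` instead, presumably so kernel launches stay visible to the upstream PyTorch profiler. A minimal hedged sketch of that ATen mechanism follows; the helper name and arguments are illustrative assumptions, not code from this commit.

// Illustrative sketch only: RECORD_FUNCTION is the profiler guard provided by
// <ATen/record_function.h>. Wrapping a launch in it makes the kernel show up
// in torch.profiler traces, which do_bench_upstream_pytorch_profiler later
// reads back via prof.events(). The surrounding function is hypothetical.
#include <ATen/record_function.h>
#include <string>
#include <vector>

void launch_with_profiling(const std::string &kernel_name) {
  RECORD_FUNCTION(kernel_name, std::vector<c10::IValue>());
  // ... submit the SYCL kernel here; the guard ends when the scope exits ...
}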

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 6 additions & 4 deletions
@@ -213,16 +213,18 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no

     function_events = prof.events()

-    functions = []
+    all_functions = []
     if isinstance(kernel_name, str):
         kernel_name = [kernel_name]
     for ker_name in kernel_name:
-        functions.extend(list(filter(lambda x: x.name.startswith(ker_name), function_events)))  # pylint: disable=cell-var-from-loop
+        functions = list(filter(lambda x: x.name.startswith(ker_name), function_events))  # pylint: disable=cell-var-from-loop
+        assert len(functions) == n_repeat, f"the profiling number for kernel: '{ker_name}' not match, {len(functions)}"
+        all_functions.append(functions)
     # profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), function_events)

-    assert len(functions) == n_repeat, f"the profiling number not match, {len(functions)}"
     # Make the time to the milliseconds.
-    times = torch.tensor([f.self_device_time_total * 1e-3 for f in functions], dtype=torch.float)
+    times = torch.tensor([sum(map(lambda elem: elem.self_device_time_total, f)) * 1e-3 for f in zip(*all_functions)],
+                         dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 5 additions & 0 deletions
@@ -309,6 +309,10 @@ def benchmark(B, M, N, K, provider):
         acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
         cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32)
         name = f'gemm_shape_{B}_{M}_{K}_{N}'
+        # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
+        # better performance.
+        if (B, M, N, K) == (1, 3072, 4096, 3072):
+            name = 'gemm_streamk_shape_3072_4096_3072'
         func = getattr(xetla_kernel, name)
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
@@ -338,6 +342,7 @@ def benchmark(B, M, N, K, provider):
             'gemm_shape_32_4096_4096_128': 'Test_32x4096x4096x128_row_row',
             'gemm_shape_4096_8_128_16384': 'Test_4096x8x128x16384_row_row',
             'gemm_shape_4096_8_16384_128': 'Test_4096x8x16384x128_row_row',
+            'gemm_streamk_shape_3072_4096_3072': 'stream_k_gemm_run',
         }

         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 3 deletions
@@ -293,9 +293,8 @@ def benchmark(M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
-            xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-            kernel_name='gpu::xetla::kernel::gemm_universal_t<dispatch_stream_k')
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles, kernel_name='stream_k_gemm_run')
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
+  mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
+  mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

   // TODO: register Triton & TritonGPU passes
   registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 22 additions & 4 deletions
@@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
 constexpr int patternBenefitClampOptimizedPattern = 20;
 constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

+struct BackendCallbacks {
+  /**
+   * A backend-specific callback for appending auxiliary data during
+   * `LocalStoreOp` conversion.
+   *
+   * @param[in] op The reference to the re-written `LocalStoreOp`.
+   * @param[in] count The number of issued LLVM instructions.
+   * @param[in] type The input type of issued LLVM instructions.
+   */
+  std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
+                     Type llvmOpType)>
+      localStoreOpConversion = nullptr;
+};
+
 void populateElementwiseOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
     PatternBenefit benefit);

-void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                   const TargetInfoBase &targetInfo,
-                                   RewritePatternSet &patterns,
-                                   PatternBenefit benefit);
+// The given callback is invoked at the end of a successful rewrite. The
+// callback receives 1) the current source op, 2) the number of issued LLVM
+// instructions and 3) their input types. Each MLIR backend can provide a
+// callback and, thus, handle backend-specific behaviors.
+void populateMemoryOpToLLVMPattern(
+    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    std::optional<BackendCallbacks> backendCallbacks = std::nullopt);

 void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
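To make the new hook concrete, here is a minimal sketch of how a backend might fill in `BackendCallbacks` and pass it to `populateMemoryOpToLLVMPattern`. Only the struct and the function signature come from this commit; the wrapper function, the lambda body, and its purpose are hypothetical.

// Hypothetical backend-side usage of the optional callback. The callback fires
// after each successful LocalStoreOp rewrite with the number and input type of
// the LLVM instructions that were emitted for it.
void populateBackendMemoryPatterns(LLVMTypeConverter &typeConverter,
                                   const TargetInfoBase &targetInfo,
                                   RewritePatternSet &patterns,
                                   PatternBenefit benefit) {
  BackendCallbacks callbacks;
  callbacks.localStoreOpConversion = [](triton::gpu::LocalStoreOp op,
                                        size_t llvmOpCount, Type llvmOpType) {
    // Backend-specific bookkeeping, e.g. recording how many shared-memory
    // stores were emitted so instruction-scheduling hints can be attached.
  };
  populateMemoryOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit,
                                callbacks);
}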

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);

-void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
-                              Type elemLlvmTy, ArrayRef<Value> srcVals,
-                              Value smemBase, ArrayRef<Value> dstStrides,
-                              Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target);
+void storeDistributedToShared(
+    MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
+    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    std::pair<size_t, Type> *const llvmOpCount = nullptr);

 inline Value getStructFromSharedMemoryObject(Location loc,
                                              const SharedMemoryObject &smemObj,
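The new out-parameter is presumably what feeds the `localStoreOpConversion` callback above. A rough sketch of how a caller could thread it through follows; only the `storeDistributedToShared` signature is taken from this commit, and the surrounding variables (`dstTy`, `srcVals`, `op`, `backendCallbacks`, and so on) are assumed to exist in the lowering pattern.

// Assumed call-site wiring inside a LocalStoreOp lowering: collect the count
// and element type of the LLVM ops emitted by the store, then hand them to
// the optional backend callback if one was registered.
std::pair<size_t, Type> llvmOpCount{0, Type()};
storeDistributedToShared(dstTy, srcTy, elemLlvmTy, srcVals, smemBase,
                         dstStrides, loc, rewriter, targetInfo, &llvmOpCount);
if (backendCallbacks && backendCallbacks->localStoreOpConversion)
  backendCallbacks->localStoreOpConversion(op, llvmOpCount.first,
                                           llvmOpCount.second);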
