Skip to content

Commit 1a9d656

Browse files
committed
Merge remote-tracking branch 'origin/main' into etiotto.remove_masks
2 parents fbdbee5 + aa01f5c commit 1a9d656

File tree

7 files changed

+23
-18
lines changed

7 files changed

+23
-18
lines changed

.github/workflows/build-test-reusable.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ jobs:
302302
run: |
303303
echo "TRITON_TEST_CMD=${{ needs.build.outputs.test-triton-command }}" | tee -a $GITHUB_ENV
304304
305-
- name: Build PTI && Run Proton tests
305+
- name: Run Proton tests
306306
if: matrix.suite == 'rest'
307307
run: |
308308
export LD_LIBRARY_PATH=${{ env.PTI_LIBS_DIR }}:$LD_LIBRARY_PATH

.github/workflows/triton-benchmarks.yml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,10 @@ jobs:
117117
cd benchmarks
118118
pip install .
119119
120-
- name: Build PTI from source
120+
- name: Build PTI
121121
run: |
122-
PTI_COMMIT_ID="$(<.github/pins/pti.txt)"
123-
git clone https://github.com/intel/pti-gpu.git
124-
cd pti-gpu
125-
git checkout $PTI_COMMIT_ID
126-
cd sdk
127-
cmake --preset linux-icpx-release
128-
BUILD_TESTING=1 PTI_BUILD_SAMPLES=1 cmake --build --preset linux-icpx-release
129-
130-
PTI_LIBS_DIR="$(pwd)/build-linux-icpx-release/lib/"
122+
./scripts/install-pti.sh --build-level-zero
123+
PTI_LIBS_DIR=$(python ./scripts/pti_lib.py)
131124
ls $PTI_LIBS_DIR
132125
echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV
133126

python/test/unit/tools/test_disasm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ def kernel(X, i: tl.constexpr):
3434
assert x[0] == 12
3535
dis = h.asm["spvdis"]
3636
# check that the spvdis has a store instruction.
37-
assert "PredicatedStore" in dis
37+
assert "OpStore" in dis

python/triton/compiler/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,12 @@ def raise_(err):
468468
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
469469
self.name, self.kernel, self.metadata.shared, self.metadata.build_flags,
470470
not self.metadata.generate_native_code, device)
471+
# PyTorch could use the updated build flags in load binary.
472+
if hasattr(driver.active.utils, "get_last_selected_build_flags"):
473+
new_build_flags = driver.active.utils.get_last_selected_build_flags()
474+
if new_build_flags != self.metadata.build_flags:
475+
self.metadata = self.metadata._replace(build_flags=new_build_flags)
476+
471477
if hasattr(self.metadata, "threads_per_warp"):
472478
warp_size = self.metadata.threads_per_warp
473479
else:

third_party/intel/backend/driver.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ sycl::context get_default_context(const sycl::device &sycl_device) {
192192
#endif
193193
}
194194

195+
static BuildFlags last_build_flag("");
196+
197+
extern "C" EXPORT_FUNC PyObject *get_last_selected_build_flags() {
198+
return Py_BuildValue("s", last_build_flag().data());
199+
}
200+
195201
extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
196202
const char *name, *build_flags_ptr;
197203
int shared;
@@ -309,7 +315,7 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
309315
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
310316
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
311317
"kernel_bundle", freeKernelBundle);
312-
318+
last_build_flag = build_flags;
313319
return Py_BuildValue("(OOiii)", kernel_bundle_py, kernel_py, n_regs,
314320
n_spills, n_max_threads);
315321

third_party/intel/backend/driver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,11 @@ def __init__(self, cache_path: str):
199199
self.shared_library.get_device_properties.argtypes = (ctypes.c_int, )
200200
self.shared_library.has_opencl_extension.restype = ctypes.py_object
201201
self.shared_library.has_opencl_extension.argtypes = (ctypes.c_int, ctypes.c_char_p)
202+
self.shared_library.get_last_selected_build_flags.restype = ctypes.py_object
202203

203204
def __getattribute__(self, name):
204-
if name in ("get_device_properties", "init_devices", "wait_on_sycl_queue", "has_opencl_extension"):
205+
if name in ("get_device_properties", "init_devices", "wait_on_sycl_queue", "has_opencl_extension",
206+
"get_last_selected_build_flags"):
205207
shared_library = super().__getattribute__("shared_library")
206208
return getattr(shared_library, name)
207209

@@ -318,6 +320,7 @@ def __init__(self):
318320
self.device_count = mod.init_devices(self.get_sycl_queue())
319321
self.wait_on_sycl_queue = mod.wait_on_sycl_queue
320322
self.has_opencl_extension = mod.has_opencl_extension
323+
self.get_last_selected_build_flags = mod.get_last_selected_build_flags
321324

322325
def get_current_device(self):
323326
import torch

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3401,10 +3401,7 @@ struct StoreOpConversion
34013401

34023402
if (maskVal) {
34033403
// Create a predicated store operation.
3404-
std::optional<bool> enablePredicated =
3405-
mlir::triton::tools::isEnvValueBool(
3406-
mlir::triton::tools::getStrEnv("TRITON_INTEL_PREDICATED"));
3407-
if (!enablePredicated.has_value() || enablePredicated.value())
3404+
if (triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED"))
34083405
rewriter.create<TritonGEN::PredicatedStoreOp>(
34093406
loc, addrElem, vecWord, b.i64_val(alignment), maskVal);
34103407
else

0 commit comments

Comments
 (0)