Commit 19de4d7

Merge OpenAI Triton commit 27c5940 (#4940)
This PR changes the Triton base from 05b2c18 to 27c5940 (Aug 11). Pass rate: 98.85%.
2 parents 6a95135 + 99796f4 commit 19de4d7

File tree

38 files changed: +390 -252 lines


include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,9 @@ class TargetInfoBase {
   virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                            Value i) const = 0;

+  virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
+                        Value selector) const = 0;
+
   virtual Value programId(RewriterBase &rewriter, Location loc,
                           ModuleOp moduleOp, ProgramIDDim axis) const = 0;
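
A note on the new hook: the (a, b, selector) signature mirrors a hardware byte permute such as PTX's prmt.b32, which selects four bytes out of the eight bytes of the pair {b, a}. The diff does not show a backend implementation, so the sketch below only illustrates the presumed semantics, in plain Python with hypothetical names:

    # Illustrative only: emulate a prmt.b32-style byte permute. Each selector
    # nibble picks one of the 8 source bytes (0-3 from `a`, 4-7 from `b`).
    def permute_bytes(a: int, b: int, selector: int) -> int:
        pool = ((b & 0xFFFFFFFF) << 32) | (a & 0xFFFFFFFF)
        out = 0
        for i in range(4):
            idx = (selector >> (4 * i)) & 0x7
            out |= ((pool >> (8 * idx)) & 0xFF) << (8 * i)
        return out

    assert permute_bytes(0x44332211, 0x88776655, 0x3210) == 0x44332211  # identity on a
    assert permute_bytes(0x44332211, 0x88776655, 0x7654) == 0x88776655  # identity on b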

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 2 deletions
@@ -571,8 +571,9 @@ SmallVector<Value> lowerLdSt(
     ArrayRef<Value> valsArray, // Input for store, output for load
     Type llvmElemTy, Value smemBase,
     std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
-    uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
-    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+    uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
+    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
+    std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
                                      Value, int, VectorType)>
         lowerInst);

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 10 additions & 8 deletions
@@ -542,18 +542,20 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
       return unpackLLVector(loc, valsVec, rewriter);
     }
   };
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
-                   calcPaddedOffset, affineOffset, maskSpanAffineOffset,
-                   rewriter, targetInfo, {}, emitLdSt);
+                   calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
+                   warpId, rewriter, targetInfo, {}, emitLdSt);
 }

 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
     Type llvmElemTy, Value smemBase,
     std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
-    uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
-    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+    uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
+    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
+    std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
                                      Value, int, VectorType)>
         lowerInst) {
@@ -599,7 +601,6 @@ SmallVector<Value> lowerLdSt(
       zerosLike(LinearLayout::identity1D(bitwidth / 8, kReg, kOffset));
   auto i8AddrLayout = i8Tile * addrLayout;

-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   auto regBaseI8 =
       applyLinearLayout(
           loc, rewriter, i8AddrLayout,
@@ -2022,16 +2023,17 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
   };

   auto noPaddingOffset = [](Value v) { return v; };
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
             /*calcPaddedOffset=*/noPaddingOffset, /*affineOffset=*/b.i32_val(0),
-            /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
+            /*maskSpanAffineOffset=*/0, laneId, warpId, rewriter, targetInfo,
             /*maybeMaxVecElems=*/{}, emitSt);
   b.barrier();
   resultVals = lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
                          /*calcPaddedOffset=*/noPaddingOffset,
                          /*affineOffset=*/b.i32_val(0),
-                         /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
-                         /*maybeMaxVecElems=*/{}, emitLd);
+                         /*maskSpanAffineOffset=*/0, laneId, warpId, rewriter,
+                         targetInfo, /*maybeMaxVecElems=*/{}, emitLd);

   // Create the result struct and replace the operation
   Value resultStruct =
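
This is a compute-once, pass-down refactor: finalizeTensorAtomicResults calls lowerLdSt twice, once for the store and once for the load, so computing the lane and warp ids in the caller emits them a single time instead of per call. A plain-Python sketch of the shape of the change, with hypothetical names, not Triton code:

    def get_lane_and_warp_id():
        # Stands in for getLaneAndWarpId, which emits fresh IR on every call.
        return 3, 1

    def lower_ld_st(vals, lane_id, warp_id):
        # The helper now receives the ids instead of computing them itself.
        return [(v, lane_id, warp_id) for v in vals]

    lane_id, warp_id = get_lane_and_warp_id()  # computed once by the caller
    stored = lower_ld_st([10, 20], lane_id, warp_id)
    loaded = lower_ld_st([30, 40], lane_id, warp_id)  # reuses the same ids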

python/test/conftest.py

Lines changed: 4 additions & 4 deletions
@@ -108,24 +108,24 @@ def fresh_cache():


 @pytest.fixture
-def fresh_knobs(monkeypatch):
+def fresh_knobs():
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    fresh_function, reset_function = _fresh_knobs_impl()
     try:
         yield fresh_function()
     finally:
         reset_function()


 @pytest.fixture
-def fresh_knobs_except_libraries(monkeypatch):
+def fresh_knobs_except_libraries():
     """
     A variant of `fresh_knobs` that keeps library path
     information from the environment as these may be
     needed to successfully compile kernels.
     """
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch, skipped_attr={"build", "nvidia", "amd"})
+    fresh_function, reset_function = _fresh_knobs_impl(skipped_attr={"build", "nvidia", "amd"})
     try:
         yield fresh_function()
     finally:
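
Test code that consumes these fixtures is unchanged; only the fixture internals stopped depending on pytest's function-scoped monkeypatch argument. A hedged usage sketch (the knob name is illustrative, not taken from the diff):

    def test_with_fresh_knobs(fresh_knobs):
        # The fixture still yields a fresh knobs object and restores the
        # originals on teardown; `cache.dir` here is an assumed knob name.
        fresh_knobs.cache.dir = "/tmp/triton-test-cache"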

python/test/unit/cuda/test_tma_descriptor.py

Lines changed: 8 additions & 5 deletions
@@ -55,9 +55,12 @@ def example_load_store_kernel(X, Y, x_off, y_off, x_size, y_size):
     store_ragged(Y, y_off, y_size, [0, 0], data)


-@pytest.mark.parametrize("write_only", [False, True])
-@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
-def test_ragged_tma(dtype, write_only):
+@pytest.mark.parametrize("dtype", [
+    "bfloat16", "float16", "float32", "float64",  # floating-point
+    "int8", "int16", "int32", "int64",  # signed integers
+    "uint8", "uint16", "uint32", "uint64"  # unsigned integers
+])
+def test_ragged_tma(dtype):

     if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 9:
         pytest.skip("Test requires Hopper or Blackwell target.")
@@ -67,10 +70,10 @@ def test_ragged_tma(dtype, write_only):

     src = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
     ref = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
-    dst = 1.0 * ref
+    dst = ref.clone()

     X = create_ragged_descriptor(src, [32, 128])
-    Y = create_ragged_descriptor(dst, [32, 128], write_only=write_only)
+    Y = create_ragged_descriptor(dst, [32, 128])

     x_off = 42
     y_off = 51
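
The switch from dst = 1.0 * ref to dst = ref.clone() is what lets the new integer dtypes work: multiplying by a Python float promotes an integer tensor to floating point, while clone() preserves the dtype. A quick standalone check:

    import torch

    ref = torch.arange(4, dtype=torch.int16)
    assert (1.0 * ref).dtype == torch.float32  # float scalar promotes the result
    assert ref.clone().dtype == torch.int16    # clone keeps the original dtype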

python/test/unit/language/test_core.py

Lines changed: 20 additions & 0 deletions
@@ -7325,6 +7325,26 @@ def simple(data, out):
     assert amdgcn_gfx.group(1) == arch


+def test_num_ctas_pre_sm90(device):
+    if not is_cuda() and not is_hip():
+        pytest.skip("Only supported on CUDA and HIP")
+
+    @triton.jit
+    def _kernel(src):
+        pass
+
+    src = torch.empty(1, device=device)
+    if is_cuda():
+        arch = "sm80"
+        msg = r"num_ctas > 1 requires NVIDIA SM90\+ \(Hopper\)"
+    else:
+        arch = "gfx942"
+        msg = r"num_ctas > 1 not supported for AMD GPUs"
+
+    with pytest.raises(ValueError, match=msg):
+        _kernel.warmup(src, grid=(1, ), num_ctas=2, arch=arch)
+
+
 # -----------------------
 # test propagate_nan
 # -----------------------
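
What this pins down from the user's side: an unsupported num_ctas now fails fast with a ValueError at warmup/compile time, before any launch. A minimal sketch mirroring the CUDA branch of the test (it assumes warmup accepts the arch override exactly as used above):

    import pytest
    import torch
    import triton

    @triton.jit
    def _noop(src):
        pass

    # Compiling for a pre-Hopper target with num_ctas=2 is rejected eagerly.
    with pytest.raises(ValueError, match=r"num_ctas > 1 requires NVIDIA SM90"):
        _noop.warmup(torch.empty(1, device="cuda"), grid=(1, ), num_ctas=2, arch="sm80")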

python/test/unit/runtime/test_driver.py

Lines changed: 4 additions & 4 deletions
@@ -10,11 +10,11 @@ def test_is_lazy():
     from importlib import reload
     reload(sys.modules["triton.runtime.driver"])
     reload(sys.modules["triton.runtime"])
-    mod = sys.modules[triton.runtime.driver.__module__]
-    assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy"))
-    assert triton.runtime.driver.active._obj is None
+    assert triton.runtime.driver._active is None
+    assert triton.runtime.driver._default is None
+    assert isinstance(triton.runtime.driver.active, getattr(triton.backends.driver, "DriverBase"))
+    assert isinstance(triton.runtime.driver.default, getattr(triton.backends.driver, "DriverBase"))
     utils = triton.runtime.driver.active.utils  # noqa: F841
-    assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))


 def test_kernel_in_thread(device):

python/triton/_internal_testing.py

Lines changed: 6 additions & 1 deletion
@@ -204,12 +204,14 @@ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> t
     return t


-def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
+def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
     from triton import knobs

     if skipped_attr is None:
         skipped_attr = set()

+    monkeypatch = pytest.MonkeyPatch()
+
     knobs_map = {
         name: knobset
         for name, knobset in knobs.__dict__.items()
@@ -237,6 +239,9 @@ def fresh_function():
     def reset_function():
         for name, knobset in knobs_map.items():
             setattr(knobs, name, knobset)
+        # `undo` should be placed before `del os.environ`
+        # Otherwise, it may restore environment variables that monkeypatch deleted
+        monkeypatch.undo()
         for k in env_to_unset:
             if k in os.environ:
                 del os.environ[k]
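
The implementation now owns its patcher: pytest.MonkeyPatch() can be instantiated directly (pytest >= 6.2) outside the fixture machinery, as long as undo() is called during teardown. A self-contained sketch of that lifecycle, with a hypothetical variable name:

    import os
    import pytest

    mp = pytest.MonkeyPatch()
    mp.setenv("TRITON_EXAMPLE_KNOB", "1")  # hypothetical variable
    assert os.environ["TRITON_EXAMPLE_KNOB"] == "1"

    # undo() restores the pre-patch environment; as the new comment in the
    # diff notes, run it before any manual os.environ cleanup so it cannot
    # resurrect variables that were meant to stay deleted.
    mp.undo()
    assert "TRITON_EXAMPLE_KNOB" not in os.environ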

python/triton/compiler/compiler.py

Lines changed: 9 additions & 4 deletions
@@ -442,13 +442,14 @@ def __init__(self, src, metadata_group, hash):
         # (e.g., checking amount of shared memory on current device)
         self.module = None
         self.function = None
+        self._run = None

     def _init_handles(self):
         if self.module is not None:
             return
         device = driver.active.get_current_device()
         # create launcher
-        self.run = driver.active.launcher_cls(self.src, self.metadata)
+        self._run = driver.active.launcher_cls(self.src, self.metadata)
         # not enough shared memory to run the kernel
         max_shared = max_shared_mem(device)
         if self.metadata.shared > max_shared:
@@ -469,10 +470,14 @@ def _init_handles(self):
         if self.metadata.num_warps * warp_size > self.n_max_threads:
             raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")

-    def __getattribute__(self, name):
-        if name == 'run':
+    @property
+    def run(self):
+        # it should be safe to do this as launch_metadata will
+        # call _init_handles before running the kernel or it
+        # was called manually or it was already initialized
+        if self._run is None:
             self._init_handles()
-        return super().__getattribute__(name)
+        return self._run

     def launch_metadata(self, grid, stream, *args):
         if knobs.runtime.launch_enter_hook is None:
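
A plausible reading of this change: __getattribute__ intercepts every attribute lookup on the object, while a property only runs when .run itself is read. A stripped-down sketch contrasting the two shapes (illustrative, not the real class):

    class OldStyle:
        # The old hook pays a Python-level check on *every* attribute lookup.
        def __getattribute__(self, name):
            if name == "run":
                pass  # the real class lazily built the launcher here
            return super().__getattribute__(name)

    class NewStyle:
        # The property runs only when `.run` is read.
        def __init__(self):
            self._run = None

        @property
        def run(self):
            if self._run is None:
                self._run = object()  # stand-in for the real launcher
            return self._run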

python/triton/runtime/driver.py

Lines changed: 16 additions & 41 deletions
@@ -2,8 +2,6 @@

 from ..backends import backends, DriverBase

-from typing import Any, Callable, Generic, TypeVar, Union
-

 def _create_driver() -> DriverBase:
     active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
@@ -12,52 +10,29 @@ def _create_driver() -> DriverBase:
     return active_drivers[0]()


-T = TypeVar("T")
-
-
-class LazyProxy(Generic[T]):
-
-    def __init__(self, init_fn: Callable[[], T]) -> None:
-        self._init_fn = init_fn
-        self._obj: Union[T, None] = None
-
-    def _initialize_obj(self) -> T:
-        if self._obj is None:
-            self._obj = self._init_fn()
-        return self._obj
-
-    def __getattr__(self, name) -> Any:
-        return getattr(self._initialize_obj(), name)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name in ["_init_fn", "_obj"]:
-            super().__setattr__(name, value)
-        else:
-            setattr(self._initialize_obj(), name, value)
-
-    def __delattr__(self, name: str) -> None:
-        delattr(self._initialize_obj(), name)
-
-    def __repr__(self) -> str:
-        if self._obj is None:
-            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
-        return repr(self._obj)
-
-    def __str__(self) -> str:
-        return str(self._initialize_obj())
-
-
 class DriverConfig:

     def __init__(self) -> None:
-        self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver)
-        self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default
+        self._default: DriverBase | None = None
+        self._active: DriverBase | None = None
+
+    @property
+    def default(self) -> DriverBase:
+        if self._default is None:
+            self._default = _create_driver()
+        return self._default
+
+    @property
+    def active(self) -> DriverBase:
+        if self._active is None:
+            self._active = self.default
+        return self._active

     def set_active(self, driver: DriverBase) -> None:
-        self.active = driver
+        self._active = driver

     def reset_active(self) -> None:
-        self.active = self.default
+        self._active = self.default


 driver = DriverConfig()
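
The net effect, which the rewritten test_is_lazy checks: laziness now lives in DriverConfig itself rather than in a wrapper, so nothing is constructed until first access and the value returned is a concrete DriverBase, not a proxy. A short sketch of the observable behavior (assuming a fresh import, as in the test):

    import triton
    from triton.backends.driver import DriverBase

    # Nothing is materialized until first touch.
    assert triton.runtime.driver._active is None

    # Reading `.active` falls back to `.default`, which calls _create_driver()
    # once and caches the result.
    active = triton.runtime.driver.active
    assert isinstance(active, DriverBase)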
