Commit 2674dbb

Merge commit '27c5940a6ba4127be7bb0806fe750c3410ac393a'

2 parents: 9f23f73 + 27c5940
File tree: 16 files changed (+229 −87 lines)

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 2 deletions
```diff
@@ -571,8 +571,9 @@ SmallVector<Value> lowerLdSt(
     ArrayRef<Value> valsArray, // Input for store, output for load
     Type llvmElemTy, Value smemBase,
     std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
-    uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
-    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+    uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
+    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
+    std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
                                      Value, int, VectorType)>
         lowerInst);
```

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 74 additions & 25 deletions
```diff
@@ -158,39 +158,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
   SmallVector<int32_t> matrix = flatten(A.getBases().begin()->second);
   assert(matrix.size() == nCol);
 
-  // We iterate the matrix following the diagonals
-  // The idea here is that we want to generate code of the form:
-  // \xor_i (x & mask_i) << s_i
-  // where s_i may by positive or negative (left or right shift)
-  // The hope here (and we see it in codegen) is that LLVM can turn
-  // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
-  // Get the i-th diagonal
-  auto getMask = [&](int i) {
+  // Row-wise popcount to detect rows that appear exactly once across columns.
+  uint32_t rowsUnique = 0;
+  {
+    SmallVector<int> rowPopCnt(nRow, 0);
+    for (int c = 0; c < nCol; ++c) {
+      uint32_t colBits = matrix[c];
+      for (int r = 0; r < nRow; ++r) {
+        if (colBits & (1u << r))
+          ++rowPopCnt[r];
+      }
+    }
+    for (int r = 0; r < nRow; ++r) {
+      if (rowPopCnt[r] == 1)
+        rowsUnique |= 1u << r;
+    }
+  }
+
+  // We iterate the matrix following the diagonals and build
+  // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
+  // then XOR everything else. This tends to encourage mad.lo codegen.
+  auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t, bool> {
     uint32_t mask = 0;
     int row = i < 0 ? -i : 0;
     int col = i < 0 ? 0 : i;
+    bool allRowsUnique = true;
     while (row < nRow && col < nCol) {
       uint32_t bitValue = (matrix[col] >> row) & 1u;
       mask |= bitValue << col;
+      allRowsUnique &= ((rowsUnique >> row) & 1u) == 1u;
       ++row;
       ++col;
     }
-    return mask;
+    return {mask, allRowsUnique};
   };
 
   uint32_t explicitCols = 0;
 
   {
     SmallVector<uint32_t> masks;
     for (int i = -nRow + 1; i < nCol; i++) {
-      masks.push_back(getMask(i));
+      masks.push_back(std::get<0>(getMaskAndAllRowsUnique(i)));
     }
     bool reachedFixedPoint = false;
     while (!reachedFixedPoint) {
       reachedFixedPoint = true;
       for (uint32_t m : masks) {
         uint32_t c = m & ~explicitCols;
-        if ((c != 0) && ((c & (c - 1)) == 0)) {
+        if (llvm::isPowerOf2_32(c)) {
           // found a single-element diagonal
           explicitCols |= c;
           reachedFixedPoint = false;
@@ -200,14 +215,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
   }
 
   // handle any diagonals that have survived
-  Value ret = b.i32_val(0);
+  SmallVector<Value> ors;
+  SmallVector<Value> xors;
   for (int i = -nRow + 1; i < nCol; i++) {
-    auto mask = getMask(i) & ~explicitCols;
+    auto [mask, allRowsUnique] = getMaskAndAllRowsUnique(i);
+    mask &= ~explicitCols;
     if (mask == 0)
       continue;
     auto masked = b.and_(x, b.i32_val(mask));
-    ret = b.xor_(ret, i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
-                             : Value(b.shl(masked, b.i32_val(-i))));
+    auto shifted = i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
+                          : Value(b.shl(masked, b.i32_val(-i)));
+    if (allRowsUnique) {
+      ors.push_back(shifted);
+    } else {
+      xors.push_back(shifted);
+    }
   }
 
   // handle any explicit columns:
@@ -219,10 +241,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
       int32_t basis = matrix[i];
       if (basis == 0)
         continue;
-      ret = b.xor_(ret, b.select(bit_is_zero, zero, b.i32_val(basis)));
+      auto select = b.select(bit_is_zero, zero, b.i32_val(basis));
+      if ((rowsUnique & basis) == basis) {
+        ors.push_back(select);
+      } else {
+        xors.push_back(select);
+      }
     }
   }
-  return ret;
+
+  auto treeReduce = [&](SmallVector<Value> &terms,
+                        std::function<Value(Value, Value)> op) -> Value {
+    if (terms.empty())
+      return b.i32_val(0);
+    while (terms.size() > 1) {
+      SmallVector<Value> next;
+      for (size_t i = 0; i + 1 < terms.size(); i += 2)
+        next.push_back(op(terms[i], terms[i + 1]));
+      if (terms.size() % 2 == 1)
+        next.push_back(terms.back());
+      terms = std::move(next);
+    }
+    return terms[0];
+  };
+
+  auto orPart = treeReduce(
+      ors, [&b](Value x, Value y) { return b.or_(x, y, /*disjoint=*/true); });
+  auto xorPart =
+      treeReduce(xors, [&b](Value x, Value y) { return b.xor_(x, y); });
+  return b.or_(orPart, xorPart, /*disjoint=*/true);
 }
 
 } // namespace triton::gpu
```
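The change to `matrixVectorProd` is easiest to see outside MLIR. The function evaluates y = A·x over GF(2), with column c of A bit-packed into `matrix[c]`, as a handful of `(x & mask) << s` diagonal terms. The new twist: an output row that is set by exactly one column can only ever receive a bit from one term, so diagonals made entirely of such rows are combined with OR (emitted as `or` with the `disjoint` flag, which LLVM may rewrite into an add and fuse into `mad.lo`), while the remaining terms are still XORed; both lists are reduced as balanced trees to shorten the dependency chain. Below is a minimal Python sketch of that scheme, an illustration rather than Triton code; it omits the explicit-column `select` peeling, which does not change the computed value:

```python
# Illustrative sketch (not Triton code) of the diagonal-decomposition scheme.
from functools import reduce
import operator

def matvec_gf2(matrix, n_row, n_col, x):
    # rows_unique: bitmask of output rows set by exactly one column of A.
    row_popcnt = [0] * n_row
    for c in range(n_col):
        for r in range(n_row):
            if (matrix[c] >> r) & 1:
                row_popcnt[r] += 1
    rows_unique = sum(1 << r for r in range(n_row) if row_popcnt[r] == 1)

    ors, xors = [], []
    for i in range(-n_row + 1, n_col):  # walk the diagonals of A
        mask, all_unique = 0, True
        row, col = (-i, 0) if i < 0 else (0, i)
        while row < n_row and col < n_col:
            if (matrix[col] >> row) & 1:
                mask |= 1 << col
            all_unique &= bool((rows_unique >> row) & 1)
            row += 1
            col += 1
        if mask == 0:
            continue
        term = (x & mask) >> i if i >= 0 else (x & mask) << -i
        (ors if all_unique else xors).append(term)

    # The real code reduces each list as a balanced tree of or.disjoint /
    # xor ops; a linear fold yields the same value.
    return reduce(operator.or_, ors, 0) ^ reduce(operator.xor, xors, 0)

# Cross-check against the defining formula y_r = XOR_c (A[r][c] & x_c),
# once with no unique rows (pure XOR path) and once with the identity
# matrix (pure OR path).
cases = [
    ([0b01, 0b10, 0b11], 2, 3),     # rows 0 and 1 each set by two columns
    ([0b001, 0b010, 0b100], 3, 3),  # identity: every row unique
]
for A, n_row, n_col in cases:
    for x in range(1 << n_col):
        ref = 0
        for r in range(n_row):
            bit = 0
            for c in range(n_col):
                bit ^= ((A[c] >> r) & 1) & ((x >> c) & 1)
            ref |= bit << r
        assert matvec_gf2(A, n_row, n_col, x) == ref
```

The remaining hunks in this file are mechanical by comparison: `getLaneAndWarpId` is hoisted out of `lowerLdSt`, so callers compute the lane and warp IDs once and pass them down: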
```diff
@@ -542,18 +589,20 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
       return unpackLLVector(loc, valsVec, rewriter);
     }
   };
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   return lowerLdSt(loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
-                   calcPaddedOffset, affineOffset, maskSpanAffineOffset,
-                   rewriter, targetInfo, {}, emitLdSt);
+                   calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
+                   warpId, rewriter, targetInfo, {}, emitLdSt);
 }
 
 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
     Type llvmElemTy, Value smemBase,
     std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
-    uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
-    const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
+    uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
+    RewriterBase &rewriter, const TargetInfoBase &targetInfo,
+    std::optional<int> maybeMaxVecElems,
    std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
                                      Value, int, VectorType)>
         lowerInst) {
@@ -599,7 +648,6 @@ SmallVector<Value> lowerLdSt(
       zerosLike(LinearLayout::identity1D(bitwidth / 8, kReg, kOffset));
   auto i8AddrLayout = i8Tile * addrLayout;
 
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   auto regBaseI8 =
       applyLinearLayout(
           loc, rewriter, i8AddrLayout,
@@ -2022,16 +2070,17 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
   };
 
   auto noPaddingOffset = [](Value v) { return v; };
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
             /*calcPaddedOffset=*/noPaddingOffset, /*affineOffset=*/b.i32_val(0),
-            /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
+            /*maskSpanAffineOffset=*/0, laneId, warpId, rewriter, targetInfo,
             /*maybeMaxVecElems=*/{}, emitSt);
   b.barrier();
   resultVals = lowerLdSt(loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
                          /*calcPaddedOffset=*/noPaddingOffset,
                          /*affineOffset=*/b.i32_val(0),
-                         /*maskSpanAffineOffset=*/0, rewriter, targetInfo,
-                         /*maybeMaxVecElems=*/{}, emitLd);
+                         /*maskSpanAffineOffset=*/0, laneId, warpId, rewriter,
+                         targetInfo, /*maybeMaxVecElems=*/{}, emitLd);
 
   // Create the result struct and replace the operation
   Value resultStruct =
```

python/test/conftest.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -108,24 +108,24 @@ def fresh_cache():
 
 
 @pytest.fixture
-def fresh_knobs(monkeypatch):
+def fresh_knobs():
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    fresh_function, reset_function = _fresh_knobs_impl()
     try:
         yield fresh_function()
     finally:
         reset_function()
 
 
 @pytest.fixture
-def fresh_knobs_except_libraries(monkeypatch):
+def fresh_knobs_except_libraries():
     """
     A variant of `fresh_knobs` that keeps library path
     information from the environment as these may be
     needed to successfully compile kernels.
     """
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch, skipped_attr={"build", "nvidia", "amd"})
+    fresh_function, reset_function = _fresh_knobs_impl(skipped_attr={"build", "nvidia", "amd"})
     try:
         yield fresh_function()
     finally:
```
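For context, a hedged sketch of how a test consumes the updated fixture; nothing changes on the caller side, since the monkeypatch plumbing now lives inside `_fresh_knobs_impl`. The knob attribute used below is an assumption for illustration, not taken from this diff:

```python
# Hypothetical usage sketch: the fixture is requested exactly as before.
# `fresh_knobs.cache.dir` is an assumed knob name for illustration.
def test_something_with_knobs(fresh_knobs):
    fresh_knobs.cache.dir = "/tmp/triton-test-cache"  # assumed knob
    # ... exercise code that reads triton.knobs ...
    # the fixture's finally-block then calls reset_function(), which
    # restores every knob and runs monkeypatch.undo() internally
```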

python/test/unit/language/test_core.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -7325,6 +7325,26 @@ def simple(data, out):
     assert amdgcn_gfx.group(1) == arch
 
 
+def test_num_ctas_pre_sm90(device):
+    if not is_cuda() and not is_hip():
+        pytest.skip("Only supported on CUDA and HIP")
+
+    @triton.jit
+    def _kernel(src):
+        pass
+
+    src = torch.empty(1, device=device)
+    if is_cuda():
+        arch = "sm80"
+        msg = r"num_ctas > 1 requires NVIDIA SM90\+ \(Hopper\)"
+    else:
+        arch = "gfx942"
+        msg = r"num_ctas > 1 not supported for AMD GPUs"
+
+    with pytest.raises(ValueError, match=msg):
+        _kernel.warmup(src, grid=(1, ), num_ctas=2, arch=arch)
+
+
 # -----------------------
 # test propagate_nan
 # -----------------------
```

python/test/unit/runtime/test_driver.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -10,11 +10,11 @@ def test_is_lazy():
     from importlib import reload
     reload(sys.modules["triton.runtime.driver"])
     reload(sys.modules["triton.runtime"])
-    mod = sys.modules[triton.runtime.driver.__module__]
-    assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy"))
-    assert triton.runtime.driver.active._obj is None
+    assert triton.runtime.driver._active is None
+    assert triton.runtime.driver._default is None
+    assert isinstance(triton.runtime.driver.active, getattr(triton.backends.driver, "DriverBase"))
+    assert isinstance(triton.runtime.driver.default, getattr(triton.backends.driver, "DriverBase"))
     utils = triton.runtime.driver.active.utils  # noqa: F841
-    assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))
 
 
 def test_kernel_in_thread(device):
```
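The rewritten assertions imply a design change in `triton.runtime.driver`: instead of a `LazyProxy` whose real object hides behind `_obj`, the module now keeps `_active`/`_default` as `None` until first use and has `active`/`default` return a concrete `DriverBase` subclass directly. A minimal sketch of that pattern, with illustrative names rather than Triton's actual implementation:

```python
# Sketch of the lazy-attribute pattern the updated test implies.
# _LazyDriverHolder and FakeDriver are illustrative stand-ins.
class _LazyDriverHolder:
    def __init__(self, factory):
        self._factory = factory  # builds the concrete driver on demand
        self._active = None
        self._default = None

    @property
    def active(self):
        if self._active is None:  # initialize and cache on first access
            self._active = self._factory()
        return self._active

    @property
    def default(self):
        if self._default is None:
            self._default = self._factory()
        return self._default


class FakeDriver:  # stands in for a triton.backends.driver.DriverBase subclass
    utils = object()


driver = _LazyDriverHolder(FakeDriver)
assert driver._active is None                 # nothing constructed yet
assert isinstance(driver.active, FakeDriver)  # built on first access
assert driver._active is not None             # and cached for reuse
```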

python/triton/_internal_testing.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -204,12 +204,14 @@ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> t
     return t
 
 
-def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
+def _fresh_knobs_impl(skipped_attr: Optional[Set[str]] = None):
     from triton import knobs
 
     if skipped_attr is None:
         skipped_attr = set()
 
+    monkeypatch = pytest.MonkeyPatch()
+
     knobs_map = {
         name: knobset
         for name, knobset in knobs.__dict__.items()
@@ -237,6 +239,9 @@ def fresh_function():
     def reset_function():
         for name, knobset in knobs_map.items():
             setattr(knobs, name, knobset)
+        # `undo` should be placed before `del os.environ`
+        # Otherwise, it may restore environment variables that monkeypatch deleted
+        monkeypatch.undo()
         for k in env_to_unset:
             if k in os.environ:
                 del os.environ[k]
```
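Constructing `pytest.MonkeyPatch()` directly, outside the `monkeypatch` fixture, is public pytest API from 6.2 onwards, and `undo()` reverts everything the instance changed, including environment variables it set or deleted. That is exactly why the comment in the diff pins the ordering: `undo()` must run before the manual `del os.environ[...]` loop, or it can resurrect variables the loop is about to clean up. A standalone sketch:

```python
# Standalone sketch of the manual MonkeyPatch lifecycle the diff relies on.
# TRITON_EXAMPLE_KNOB is a made-up variable name.
import os

import pytest

mp = pytest.MonkeyPatch()
mp.setenv("TRITON_EXAMPLE_KNOB", "1")
assert os.environ["TRITON_EXAMPLE_KNOB"] == "1"

mp.undo()  # restores the environment as it was before setenv
assert "TRITON_EXAMPLE_KNOB" not in os.environ
```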

python/triton/compiler/compiler.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -442,13 +442,14 @@ def __init__(self, src, metadata_group, hash):
         # (e.g., checking amount of shared memory on current device)
         self.module = None
         self.function = None
+        self._run = None
 
     def _init_handles(self):
         if self.module is not None:
             return
         device = driver.active.get_current_device()
         # create launcher
-        self.run = driver.active.launcher_cls(self.src, self.metadata)
+        self._run = driver.active.launcher_cls(self.src, self.metadata)
         # not enough shared memory to run the kernel
         max_shared = max_shared_mem(device)
         if self.metadata.shared > max_shared:
@@ -469,10 +470,14 @@ def _init_handles(self):
         if self.metadata.num_warps * warp_size > self.n_max_threads:
             raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads")
 
-    def __getattribute__(self, name):
-        if name == 'run':
+    @property
+    def run(self):
+        # it should be safe to do this as launch_metadata will
+        # call _init_handles before running the kernel or it
+        # was called manually or it was already initialized
+        if self._run is None:
             self._init_handles()
-        return super().__getattribute__(name)
+        return self._run
 
     def launch_metadata(self, grid, stream, *args):
         if knobs.runtime.launch_enter_hook is None:
```
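The motivation for replacing `__getattribute__` with a property: `__getattribute__` intercepts every attribute lookup on the object, so the old code paid a Python-level `name == 'run'` check on each access to any attribute, while a property getter runs only when `run` itself is read, and the cached `_run` makes repeated reads cheap. A minimal sketch of the same lazy-initialization shape, with illustrative names rather than Triton's `CompiledKernel`:

```python
# Sketch of the lazy-initialization pattern the diff switches to.
class LazyLauncher:
    def __init__(self):
        self._run = None  # created lazily, mirroring self._run in the diff

    def _init_handles(self):
        # stands in for building the real launcher from the driver
        self._run = lambda *args: None

    @property
    def run(self):
        if self._run is None:  # first read triggers initialization
            self._init_handles()
        return self._run


kernel = LazyLauncher()
assert kernel._run is None  # nothing built at construction time
kernel.run()                # property getter initializes, then launches
assert kernel._run is not None
```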
