
Commit 3339986

Merge commit '82fec379583e72bd78d40d7cf3e980808669a428'
2 parents: ea006f2 + 82fec37

19 files changed: +862 lines, -66 lines

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
   mlir::registerTritonAMDGPUReorderInstructions();
+  mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUCanonicalizePointers();
   mlir::registerTritonAMDGPUConvertToBufferOps();

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 1 deletion
@@ -163,7 +163,8 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
 LogicalResult getConvertBackwardSlice(
     Value root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr,
+    std::function<Value(Value, Attribute)> getExistingConversion = nullptr);

 // Populate pattern to remove dead cycles in ForOp.
 void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_STREAM_PREFETCH",
+    "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_LLVM_DEBUG_ONLY",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",

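Together with the registerTritonAMDGPUBlockPingpong registration above, this entry makes TRITON_HIP_USE_BLOCK_PINGPONG one of the environment variables that invalidate Triton's compilation cache, so flipping the flag cannot silently reuse stale binaries. A minimal sketch of how such a knob is typically driven from user code; the toy kernel and the assumption that the variable is read as a boolean flag are illustrative and not shown in this diff:

    import os

    # Assumption: like other TRITON_HIP_* knobs, this is read as a boolean flag by
    # the AMD backend. Set it before kernels are compiled so the cache key (which
    # now includes this variable) reflects it.
    os.environ["TRITON_HIP_USE_BLOCK_PINGPONG"] = "1"

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, N: tl.constexpr):  # hypothetical toy kernel
        offs = tl.arange(0, N)
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs))


    if torch.cuda.is_available():  # ROCm builds of torch also expose torch.cuda
        src = torch.arange(128, device="cuda", dtype=torch.float32)
        dst = torch.empty_like(src)
        copy_kernel[(1, )](src, dst, N=128)

Because the variable is listed in CACHE_INVALIDATING_ENV_VARS, runs with the flag set and unset are compiled and cached separately.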
lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 46 additions & 20 deletions
@@ -116,17 +116,13 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
+
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
-  bool hasRematValue(Value value, Attribute encoding) {
-    return rematMapping.contains({value, encoding});
-  }
-  // Return the remat'ed value in the given encoding.
-  Value getRematValue(Value value, Attribute encoding) {
-    auto it = rematMapping.find({value, encoding});
-    assert(it != rematMapping.end());
-    return it->second;
-  }
+  // Get the remat'ed value in the given encoding, if one already exists and
+  // is different then the layout conversion root.
+  Value getRematValue(Value value, Attribute encoding, Value root) const;
+
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);

@@ -137,6 +133,11 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);

+  LogicalResult getRematerializableSlice(
+      Value root, Attribute rootEncoding, SetVector<Value> &slice,
+      DenseMap<Value, Attribute> &layout,
+      std::function<bool(Operation *)> stopPropagation = nullptr);
+
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating

@@ -157,6 +158,21 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
   mappedValues[old] = encoding;
 }

+Value LayoutRematerialization::getRematValue(Value value, Attribute encoding,
+                                             Value root) const {
+  Value remat = rematMapping.lookup({value, encoding});
+  if (!remat)
+    return {};
+  // If the remat'ed value is a conversion result, make sure it is different
+  // than the root of the one we're looking at.
+  if (auto cvt = remat.getDefiningOp<ConvertLayoutOp>()) {
+    if (cvt.getSrc() == root)
+      return {};
+  }
+  // This remat'ed value can be reused.
+  return remat;
+}
+
 // Remove unneeded values now that we are done with the rematMapping.
 void LayoutRematerialization::cleanup() {
   for (Operation *op : llvm::reverse(opToDelete))

@@ -766,8 +782,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
+    if (Value remat = getRematValue(v, layoutIt->second, convertOp.getSrc())) {
+      mapping.map(v, remat);
       valuesWithExistingRemat.insert(v);
       continue;
     }

@@ -928,12 +944,17 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }

-LogicalResult getRematerializableSlice(
+LogicalResult LayoutRematerialization::getRematerializableSlice(
     Value root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr) {
-  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
-                                                 layout, stopPropagation);
+    std::function<bool(Operation *)> stopPropagation) {
+  // Allow re-using existing conversions for a value.
+  auto getExistingConversion = [&](Value value, Attribute encoding) -> Value {
+    return getRematValue(value, encoding, root);
+  };
+  LogicalResult result =
+      getConvertBackwardSlice(root, slice, rootEncoding, layout,
+                              stopPropagation, getExistingConversion);
   if (result.failed() || slice.empty())
     return failure();

@@ -950,8 +971,14 @@ LogicalResult getRematerializableSlice(
 void LayoutRematerialization::backwardRematerialization() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
-  funcOp.walk(
-      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
+  funcOp.walk([&](ConvertLayoutOp convertOp) {
+    convertOps.push_back(convertOp);
+    // Add existing layout conversions as rematerializations of themselves. This
+    // enables rematerialization of other conversions to re-use existing
+    // conversions. Importantly, don't add them to `mappedValues`.
+    rematMapping.insert(
+        {{convertOp.getSrc(), convertOp.getType().getEncoding()}, convertOp});
+  });
   for (ConvertLayoutOp convertOp : convertOps) {
     backwardRematerialization(convertOp);
   }

@@ -976,14 +1003,13 @@ void LayoutRematerialization::backwardRematerialization(
   // careful with the heuristics for both correctness and perf
   if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
     return;
-  Value oldV = convertOp->getOperand(0);
+  Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
   // Check to see if there are existing remat'ed values for the pair of oldValue
   // and encoding.
-  if (hasRematValue(oldV, targetType.getEncoding())) {
+  if (Value newV = getRematValue(oldV, targetType.getEncoding(), oldV)) {
     // Replace it with the remat'ed value.
-    Value newV = getRematValue(oldV, targetType.getEncoding());
     convertOp.replaceAllUsesWith(newV);
     opToDelete.insert(convertOp);
     LDBG("found remat'ed value" << newV);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 11 additions & 5 deletions
@@ -757,11 +757,11 @@ static bool isFreeConvert(Operation *op) {
                                convertOp.getType());
 }

-LogicalResult
-getConvertBackwardSlice(Value root, SetVector<Value> &slice,
-                        Attribute rootEncoding,
-                        DenseMap<Value, Attribute> &layout,
-                        std::function<bool(Operation *)> stopPropagation) {
+LogicalResult getConvertBackwardSlice(
+    Value root, SetVector<Value> &slice, Attribute rootEncoding,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation,
+    std::function<Value(Value, Attribute)> getExistingConversion) {
   DenseSet<std::pair<Value, Attribute>> seen;
   SmallVector<std::pair<Value, Attribute>> queue;

@@ -802,6 +802,12 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,

       continue;
     }
+    Value existing;
+    if (getExistingConversion &&
+        (existing = getExistingConversion(currentValue, encoding))) {
+      enqueue(existing, encoding);
+      continue;
+    }
     if (auto *definingOp = currentValue.getDefiningOp()) {
       // If the op has multiple results we need to update all results layout.
       for (Value result : definingOp->getResults()) {

python/test/unit/runtime/test_cache.py

Lines changed: 20 additions & 20 deletions
@@ -199,7 +199,7 @@ def kernel(X, i: tl.int32):
     kernel[(1, )](x, 8)
     kernel[(1, )](x, 16)
     kernel[(1, )](x, 17)
-    assert len(kernel.cache[device]) == 3
+    assert len(kernel.device_caches[device][0]) == 3


 GLOBAL_DEFAULT_ARG = 1

@@ -223,7 +223,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
     assert x == torch.ones_like(x)

     device = getattr(torch, device).current_device()
-    assert len(kernel.cache[device]) == 1
+    assert len(kernel.device_caches[device][0]) == 1


 GLOBAL_VAR: tl.constexpr = 1

@@ -416,13 +416,13 @@ def kernel_add(a, b, o, N: tl.constexpr):
         32,
     ]
     device = getattr(torch, device).current_device()
-    assert len(kernel_add.cache[device]) == 0
+    assert len(kernel_add.device_caches[device][0]) == 0
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     kernel_add.warmup(*args, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     kernel_add.warmup(*args, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1


 def test_jit_debug(device) -> None:

@@ -433,12 +433,12 @@ def kernel(tmp):

     device = getattr(torch, device).current_device()
     tmp = torch.tensor([1], dtype=torch.int32, device=device)
-    assert len(kernel.cache[device]) == 0
+    assert len(kernel.device_caches[device][0]) == 0
     kernel[(1, )](tmp, debug=False)
-    assert len(kernel.cache[device]) == 1
+    assert len(kernel.device_caches[device][0]) == 1
     kernel[(1, )](tmp, debug=True)
-    assert len(kernel.cache[device]) == 2
-    bins = list(kernel.cache[device].values())
+    assert len(kernel.device_caches[device][0]) == 2
+    bins = list(kernel.device_caches[device][0].values())
     assert bins[0].asm['ttir'] != bins[1].asm['ttir']


@@ -455,18 +455,18 @@ def kernel_add_device(a, b, o, N: tl.constexpr):
         add_fn(a, b, o, N)

     device = getattr(torch, device).current_device()
-    assert len(kernel_add_device.cache[device]) == 0
+    assert len(kernel_add_device.device_caches[device][0]) == 0
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add_device.cache[device]) == 1
-    bins = list(kernel_add_device.cache[device].values())
+    assert len(kernel_add_device.device_caches[device][0]) == 1
+    bins = list(kernel_add_device.device_caches[device][0].values())
     inline_ttir = bins[0].asm['ttir']
     add_fn.noinline = True
     add_fn.hash = None
     kernel_add_device.hash = None
-    kernel_add_device.cache[device].clear()
+    kernel_add_device.device_caches[device][0].clear()
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add_device.cache[device]) == 1
-    bins = list(kernel_add_device.cache[device].values())
+    assert len(kernel_add_device.device_caches[device][0]) == 1
+    bins = list(kernel_add_device.device_caches[device][0].values())
     noinline_ttir = bins[0].asm['ttir']
     assert inline_ttir != noinline_ttir

@@ -514,12 +514,12 @@ def cache_hook(*args, **kwargs):

     # clear the cache
     shutil.rmtree(fresh_triton_cache)
-    kernel_add.cache[device].clear()
+    kernel_add.device_caches[device][0].clear()

     # preload the kernel
     kernel_preload = kernel_add.preload(specialization_data)
     assert kernel_preload.hash == hash
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1

     # we should hit the cache and not compile anything
     counter = 0

@@ -532,7 +532,7 @@ def inc_counter(*args, **kwargs):
     final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
     JITFunction.cache_hook = None
     assert counter == 0
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     assert final_kernel.hash == hash

     # test that we can't preload a mismatched kernel

@@ -572,7 +572,7 @@ def compiled_hook(*args, **kwargs):
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
     assert specialization_data is not None and specialization_data_compiled == specialization_data
     assert is_warmup is True
-    assert key in kernel_add.cache[getattr(torch, device).current_device()]
+    assert key in kernel_add.device_caches[getattr(torch, device).current_device()][0]


 @pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())

python/triton/runtime/jit.py

Lines changed: 16 additions & 16 deletions
@@ -544,47 +544,49 @@ def add_pre_run_hook(self, hook):
         assert callable(hook)
         self.pre_run_hooks.append(hook)

-    def create_binder(self, backend):
+    def create_binder(self):
         """
         Precompute as much as possible.
         """
         from ..compiler import CompiledKernel, compile, ASTSource, make_backend
+        target = driver.active.get_current_target()
+        backend = make_backend(target)
         self.CompiledKernel = CompiledKernel
         self.compile = compile
         self.ASTSource = ASTSource
-        self.make_backend = make_backend
-        self.binder = create_function_from_signature(self.signature, self.params, backend)
+        binder = create_function_from_signature(self.signature, self.params, backend)
         self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
         self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
         self.specialised_indices = [
             i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
         ]
+        return [target, backend, binder]

     def run(self, *args, grid, warmup, **kwargs):
         kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"

         # parse options
-        from ..compiler import make_backend
         device = driver.active.get_current_device()
         stream = driver.active.get_current_stream(device)
-        target = driver.active.get_current_target()
-        backend = make_backend(target)

         # Execute pre run hooks with args and kwargs
         for hook in self.pre_run_hooks:
             hook(*args, **kwargs)

-        if self.binder is None:
-            self.create_binder(backend)
-
-        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
+        # This is a length-4 list [kernel_cache, target, backend, binder]:
+        device_cache = self.device_caches[device]
+        if len(device_cache) == 1:
+            device_cache[1:] = self.create_binder()
+        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = device_cache[3](*args, **kwargs)

         # compute cache key
         key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
-        kernel = self.cache[device].get(key, None)
+        kernel = device_cache[0].get(key, None)

         if kernel is None:
             # Kernel is not cached; we have to compile.
+            target = device_cache[1]
+            backend = device_cache[2]
             options = backend.parse_options(kwargs)

             # deprecated arguments

@@ -625,7 +627,7 @@ def run(self, *args, grid, warmup, **kwargs):
                 target=target,
                 options=options.__dict__,
             )
-            self.cache[device][key] = kernel
+            device_cache[0][key] = kernel
             self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)

         # Check that used global values have not changed.

@@ -669,8 +671,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
         self.repr = lambda _: fn.__name__ if repr is None else repr(_)
         self.launch_metadata = launch_metadata

-        self.binder = None
-
         self.params = []
         for i, param in enumerate(self.signature.parameters.values()):
             dns = i in do_not_specialize or param.name in do_not_specialize

@@ -681,7 +681,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
         self.src = textwrap.dedent(inspect.getsource(fn))
         self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
         # cache of just-in-time compiled kernels
-        self.cache = defaultdict(dict)
+        self.device_caches = defaultdict(lambda: [{}])
         self.hash = None

         # Map of global variables used by the function and any functions it

@@ -750,7 +750,7 @@ def preload(self, specialization_data):
         }
         key = deserialized_obj['key']
         kernel = compile(src, None, options)
-        self.cache[device][key] = kernel
+        self.device_caches[device][0][key] = kernel
         return kernel

     # we do not parse `src` in the constructor because
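
In short, jit.py replaces the flat per-device dict self.cache[device] with self.device_caches[device], a list that starts as [kernel_cache] and is lazily extended to [kernel_cache, target, backend, binder] on the first run() for that device, so the target, backend, and argument binder are built once per device rather than on every launch. A standalone sketch of that lazy-initialization pattern; ToyJITFunction and the make_target/make_backend/make_binder helpers are placeholders, while the real code uses driver.active, make_backend, and create_function_from_signature as shown in the diff above:

    from collections import defaultdict

    # Placeholder stand-ins for the real driver/backend/binder machinery.
    def make_target(device):
        return f"target-for-{device}"

    def make_backend(target):
        return f"backend-for-{target}"

    def make_binder(backend):
        # The real binder parses arguments into a specialization key; here we
        # simply echo the arguments and use their repr as the key.
        return lambda *args: (args, str(args))

    class ToyJITFunction:
        def __init__(self):
            # One list per device: [kernel_cache] grows to
            # [kernel_cache, target, backend, binder] on first use.
            self.device_caches = defaultdict(lambda: [{}])

        def create_binder(self, device):
            target = make_target(device)
            backend = make_backend(target)
            binder = make_binder(backend)
            return [target, backend, binder]

        def run(self, device, *args):
            device_cache = self.device_caches[device]
            if len(device_cache) == 1:  # binder not yet created for this device
                device_cache[1:] = self.create_binder(device)
            bound_args, key = device_cache[3](*args)
            kernel = device_cache[0].get(key)
            if kernel is None:
                kernel = f"compiled({bound_args})"  # stand-in for backend compilation
                device_cache[0][key] = kernel
            return kernel

    fn = ToyJITFunction()
    print(fn.run(0, 1.0, 2.0))  # compiles
    print(fn.run(0, 1.0, 2.0))  # cache hit

The test_cache.py updates above follow directly from this layout: the kernel cache is element 0 of the per-device list, hence kernel.device_caches[device][0].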

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ set(TRITON_TEST_DEPENDS
   triton-opt
   triton-tensor-layout
   triton-translate
+  triton-llvm-opt
 )

 set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")
