Skip to content

Commit 685f544

Browse files
committed
More bugfixes; kernel chaining
1 parent 6b0ba0d commit 685f544

File tree

7 files changed

+253
-54
lines changed

tensorforge/backend/instructions/builders/multilinear_builder.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from tensorforge.backend.scopes import Scopes
44
from tensorforge.backend.symbol import Symbol, SymbolType, SymbolView
55
from tensorforge.backend.instructions.allocate import RegisterAlloc
6-
from tensorforge.backend.instructions.memory.load import GlbToShrLoader
6+
from tensorforge.backend.instructions.memory.load import GlbToShrLoader, GlbToRegLoader
77
from tensorforge.backend.instructions.clear_registers import ClearRegisters
88
from tensorforge.backend.instructions.memory.store import StoreRegToGlb, StoreRegToShr, StoreRegToReg
99
from tensorforge.backend.instructions.sync_block import SyncThreads
@@ -19,8 +19,6 @@
1919

2020

2121
class MultilinearBuilder(AbstractBuilder):
22-
GemmClass = None
23-
2422
def __init__(self,
2523
context: Context,
2624
scopes: Scopes,
@@ -44,6 +42,7 @@ def __init__(self,
4442
self._dest_regs = None
4543

4644
self._use_registers_always = self._context.get_vm().get_hw_descr().vendor in ['amd']
45+
self._preload_registers = False
4746
self._deferred_stores = {}
4847
self._temporaries = {}
4948

@@ -71,7 +70,7 @@ def build(self, ops: List[Symbol], dest_obj: Tensor, descr: MultilinearDescr):
7170
# TODO: check if we always can allow a direct global memory load
7271
def _make_load_op(self, i):
7372

74-
prefer_broadcast = self._context.get_vm().get_hw_descr().vendor == 'amd'
73+
prefer_broadcast = self._context.get_vm().get_hw_descr().vendor in ['amd']
7574

7675
has_lead_dim = 0 in self._descr.target[i]
7776
transpose = self._descr.permute[i] != [j for j in range(len(self._descr.target[i]))]
@@ -113,8 +112,13 @@ def _make_load_op(self, i):
113112
self._loaders_cache[self._mem_regions[i]] = load_op
114113
self._instructions.append(load_op)
115114
else:
116-
# Note: operand will reside in glb. mem for gemm operation
117-
self._mem_regions[i] = self._ops[i]
115+
if self._preload_registers:
116+
self._mem_regions[i], load_op = self._make_loader_and_symbol_reg(self._ops[i].symbol, is_transpose=self._descr.permute[i])
117+
self._loaders_cache[self._mem_regions[i]] = load_op
118+
self._instructions.append(load_op)
119+
else:
120+
# Note: operand will reside in glb. mem for gemm operation
121+
self._mem_regions[i] = self._ops[i]
118122

119123
elif self._ops[i].symbol.stype == SymbolType.SharedMem or self._ops[i].symbol.stype == SymbolType.Register:
120124
if self._ops[i].symbol in self._loaders_cache.keys():
@@ -147,6 +151,32 @@ def _make_load_op(self, i):
147151
else:
148152
raise InternalError(f'gemm-builder: op{i} ({self._ops[i].symbol.name}) must be either in shr or glb mem, given: {self._ops[i].symbol.stype}')
149153

154+
def _make_loader_and_symbol_reg(self, operand, is_transpose) -> Tuple[Symbol, GlbToRegLoader]:
155+
regsize = 1
156+
threads = self._num_threads
157+
lead_dim = [0] # [t for t in self._descr.target[0] if t >= 0]
158+
159+
for d, dim in enumerate(operand.bbox.sizes()):
160+
if d not in lead_dim or threads == 0:
161+
regsize *= dim
162+
else:
163+
regsize *= (dim + threads - 1) // threads
164+
threads //= dim
165+
name = self._name_registers()
166+
regmem = RegMemObject(name, regsize)
167+
registers = Symbol(name=name, stype=SymbolType.Register, obj=regmem)
168+
registers.num_threads = self._num_threads
169+
registers.datatype = self._context.fp_type
170+
self._scopes.add_symbol(registers)
171+
registerAlloc = RegisterAlloc(self._context, registers, regsize, 0.0)
172+
self._instructions.append(registerAlloc)
173+
174+
load_op = GlbToRegLoader(context=self._context,
175+
dest=registers,
176+
src=operand,
177+
num_threads=self._num_threads)
178+
return SymbolView(registers), load_op
179+
150180
def _make_loader_and_symbol(self, operand, is_transpose) -> Tuple[Symbol, GlbToShrLoader]:
151181
shr_mem_region = Symbol(name=self._name_shr_reg(),
152182
stype=SymbolType.SharedMem,
@@ -170,7 +200,14 @@ def _alloc_register_array(self):
170200
regsize = 1
171201
threads = self._num_threads
172202
lead_dim = [0] # [t for t in self._descr.target[0] if t >= 0]
173-
for d, dim in enumerate(self._dest_obj.bbox.sizes()):
203+
204+
# TODO: shrink to enumerate(self._dest_obj.bbox.sizes())
205+
if self._add:
206+
sizes = self._get_target_symbol().data_view._bbox.sizes()
207+
else:
208+
sizes = self._dest_obj.bbox.sizes()
209+
210+
for d, dim in enumerate(sizes):
174211
if d not in lead_dim or threads == 0:
175212
regsize *= dim
176213
else:

tensorforge/backend/instructions/compute/primitives/amd.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -356,16 +356,16 @@ def reduction(writer: Writer, source, target, operation, blocks):
356356
var = tempvar
357357

358358
def cdna1(ctx):
359-
arch = ctx.get_vm().get_hw_descr().name
359+
arch = ctx.get_vm().get_hw_descr().model
360360
return arch in ('gfx908', 'gfx90a', 'gfx942', 'gfx950')
361361

362362
def cdna2(ctx):
363-
arch = ctx.get_vm().get_hw_descr().name
363+
arch = ctx.get_vm().get_hw_descr().model
364364
return arch in ('gfx90a', 'gfx942', 'gfx950')
365365

366366
def amdarch(ctx):
367-
archstr = ctx.get_vm().get_hw_descr().name
368-
return int(arch[3:], base=16)
367+
archstr = ctx.get_vm().get_hw_descr().model
368+
return int(archstr[3:], base=16)
369369

370370
def mfma_emu_int8(writer: Writer, C, B, A, c, a, b):
371371
# cf. the Ozaki II paper
@@ -416,10 +416,10 @@ def mfma_emu_bf16_f32(writer: Writer, C, B, A, c, a, b):
416416
writer(f'const bfloat16x4 {B3} = bfloat16x4({Br} - {B2});')
417417
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B1}), {C}, {c}, {a}, {b});')
418418
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B2}), {C}, {c}, {a}, {b});')
419-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B3}), {C}, {c}, {a}, {b});')
420419
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A2}), get_native_vector({B1}), {C}, {c}, {a}, {b});')
421-
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A2}), get_native_vector({B2}), {C}, {c}, {a}, {b});')
420+
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A1}), get_native_vector({B3}), {C}, {c}, {a}, {b});')
422421
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A3}), get_native_vector({B1}), {C}, {c}, {a}, {b});')
422+
writer(f'{C} = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({A2}), get_native_vector({B2}), {C}, {c}, {a}, {b});')
423423

424424
def matmul32(writer: Writer, C, B, A, M, N, K, kx, threads):
425425
with writer.AnonymousScope():
@@ -544,8 +544,39 @@ def hfma(writer: Writer, C, A, B, repeat, datatype, threads, ctx):
544544
if b is not None:
545545
func(writer, c, a, b, j)
546546

547+
def wmma3atom(threads):
548+
assert threads == 32
549+
550+
N = 16
551+
M = 16
552+
K = 16
553+
554+
for i in range(N):
555+
writer(f'const auto {a}_{i} = tensorforge::broadcast<32, 16, 0>({A}_{i});')
556+
for j in range(N):
557+
writer(f'const auto {b}_{j} = tensorforge::broadcast<32, 16, 0>({B}_{j});')
558+
559+
writer(f'tensorforge::transpose16x16({",".join(f"{b}_{i}" for i in range(N))});')
560+
561+
writer(f'VectorT<short, 16> {a}_p1;')
562+
writer(f'VectorT<short, 16> {a}_p2;')
563+
writer(f'VectorT<short, 16> {a}_p3;')
564+
writer(f'VectorT<short, 16> {b}_p1;')
565+
writer(f'VectorT<short, 16> {b}_p2;')
566+
writer(f'VectorT<short, 16> {b}_p3;')
567+
568+
writer(f'VectorT<float, 8> {c}{"{}"};')
569+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p1, {c});')
570+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p2, {b}_p1, {c});')
571+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p2, {c});')
572+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p3, {b}_p1, {c});')
573+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p3, {c});')
574+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p2, {b}_p2, {c});')
575+
576+
# TODO: gfx1200, f'__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12'
577+
547578
def matmul(writer, C, A, B, M, N, K, kx, threads, dtype, sparse, ctx):
548-
if cdna1(ctx) and not sparse and dtype == Datatype.F32:
579+
if amdarch(ctx) >= 0x908 and amdarch(ctx) < 0x1000 and not sparse and dtype == Datatype.F32:
549580
matmul32(writer, C, A, B, M, N, K, kx, threads)
550581
else:
551582
ab = []

tensorforge/backend/symbol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def encode_values(self, pos, runIdx, writer, context: Context, variable, index:
444444

445445
def load_linear(self, writer, context: Context, variable, index):
446446
if self.stype == SymbolType.Register:
447-
access = f'{self.name}[{index}]'
447+
access = f'{self.name}[{index // 32}]' # TODO
448448
else:
449449
access = f'{self.name}[{index} + threadIdx.x]'
450450
writer(f'{self.get_fptype()} {variable} = {access};')

tensorforge/frontend/yateto.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tensorforge.common.matrix.spp import FullSPP, BoundingBoxSPP, ListSPP
88
from tensorforge.common.matrix.boundingbox import BoundingBox as BBox
99
from tensorforge.generators.generator import Generator as TensorForgeGenerator
10-
from tensorforge.generators.descriptions import MultilinearDescr, ElementwiseDescr, GridBarrierDescr, GridFenceDescr
10+
from tensorforge.generators.descriptions import MultilinearDescr, ElementwiseDescr, GridBarrierDescr, GridFenceDescr, RegionDescription
1111

1212
from tensorforge.ir.data.variable import TensorView, TensorAlloc
1313
from tensorforge.ir.data.variable import TensorData
@@ -27,6 +27,9 @@ def __init__(self, arch):
2727
self._ir_list = []
2828
self._tensor_list = {}
2929

30+
# TODO: maybe remove again
31+
self._prefix = ""
32+
3033
def add_operation(self, dest, ops, target, permute, add):
3134
self._cache_matrices(dest, ops, target, permute)
3235
can_be_aligned = self._can_be_aligned(dest, ops, target, permute)
@@ -96,9 +99,9 @@ def get_tensor(self, op, can_be_aligned, dims):
9699
if isinstance(op, (float, int)):
97100
return SubTensor(tensor = Tensor([], Addressing.SCALAR, data = [op]))
98101
elif self.is_scalar(op):
99-
return SubTensor(self._cache[op.name()])
102+
return SubTensor(self._cache[f'{self._prefix}{op.name()}'])
100103
else:
101-
tensor = self._cache[op.name]
104+
tensor = self._cache[f'{self._prefix}{op.name}']
102105
currentPreShape = BBox([s for s, _ in op.eqspp.nnzbounds()], [e+1 for _, e in op.eqspp.nnzbounds()])
103106

104107
tml = op.memoryLayout
@@ -136,10 +139,10 @@ def assigner(pretensor):
136139
if self.is_scalar(pretensor):
137140
self.make_tensor(pretensor, False, None)
138141
indicesIndexed[pretensor.name()] = []
139-
subTensor = SubTensor(self._cache[pretensor.name()], BBox([], []))
142+
subTensor = SubTensor(self._cache[f'{self._prefix}{pretensor.name()}'], BBox([], []))
140143
else:
141144
bbox = BBox([s for s, _ in pretensor.eqspp().nnzbounds()], [e+1 for _, e in pretensor.eqspp().nnzbounds()])
142-
subTensor = SubTensor(self._cache[pretensor.name()], bbox)
145+
subTensor = SubTensor(self._cache[f'{self._prefix}{pretensor.name()}'], bbox)
143146
return subTensor, indicesIndexed[pretensor.name()]
144147

145148
for statement in statements:
@@ -203,13 +206,17 @@ def make_tensor(self, op, can_be_aligned, dims):
203206
entry = self._get_tensorforge_matrix(op)
204207
entry_name = op.name
205208

209+
entry_name = f'{self._prefix}{entry_name}'
210+
206211
if not (entry_name in self._cache and entry.is_same(self._cache[entry_name])):
207212
self._cache[entry_name] = entry
208213

209214
def tensor_ref(self, d):
210215
name = d['name']
211216
eqspp = d['spp']
212217

218+
name = f'{self._prefix}{name}'
219+
213220
assert(name in self._cache)
214221

215222
return SubTensor(self._cache[name], self._cache[name].bbox)
@@ -226,6 +233,8 @@ def tensor_ref_new(self, d):
226233

227234
def add_tensor(self, d):
228235
name = d['name']
236+
name = f'{self._prefix}{name}'
237+
229238
datatype = Datatype.ytt2enum(d['datatype'])
230239

231240
datatype_new = BaseDatatype.ytt2enum(d['datatype'])
@@ -276,16 +285,17 @@ def _cache_matrices(self, dest, ops, target, permute):
276285

277286
if dest.is_temporary: # (dest is never a scalar---for the time being)
278287
self.make_tensor(dest, can_be_aligned, [i for i in range(len(dest.indices))])
279-
self._tmp_matrices[dest.name] = self._cache[dest.name]
288+
self._tmp_matrices[f'{self._prefix}{dest.name}'] = self._cache[f'{self._prefix}{dest.name}']
280289
else:
281290
self.make_tensor(dest, can_be_aligned, [i for i in range(len(dest.indices))])
282291

283292

284293

285294
def _add_scalar(self, scalar):
286-
tensor = Tensor([], Addressing.SCALAR, alias=scalar.name(), datatype=self._datatype(scalar.datatype))
287-
self._tmp_matrices[scalar.name()] = tensor # SubTensor(tensor, tensor.bbox)
288-
return self._tmp_matrices[scalar.name()]
295+
name = f'{self._prefix}{scalar.name()}'
296+
tensor = Tensor([], Addressing.SCALAR, alias=name, datatype=self._datatype(scalar.datatype))
297+
self._tmp_matrices[name] = tensor # SubTensor(tensor, tensor.bbox)
298+
return self._tmp_matrices[name]
289299

290300
def deduce_addresing(self, term):
291301
if term.is_compute_constant:
@@ -323,7 +333,7 @@ def _get_tensorforge_matrix(self, tensor):
323333
return yi.gen_matrix(shape,
324334
bboxrange,
325335
addressing=addr_mode,
326-
name=tensor.name,
336+
name=f'{self._prefix}{tensor.name}',
327337
is_tmp=tensor.is_temporary,
328338
permute=None,
329339
pattern=pattern,
@@ -345,28 +355,35 @@ def _gen_call_site(self, generator):
345355
if matrix.is_tmp or matrix.addressing == Addressing.NONE:
346356
offset_name_map[name] = '0'
347357
else:
348-
offset_name_map[name] = f'extraOffset_{name}'
358+
parts = name.split('.')
359+
assert len(parts) <= 2
360+
varname = f'extraOffset_{parts[-1]}'
361+
if len(parts) == 2:
362+
offset_name_map[name] = f'{parts[0]}.{varname}'
363+
else:
364+
offset_name_map[name] = varname
349365

350366
return generator.generate_call_site(mat_name_map,
351-
offset_name_map,
352-
'numElements',
353-
'flags',
354-
'streamPtr')
367+
offset_name_map)
355368

356369
def _append_operation(self, op):
357370
if isinstance(op, (float, int)):
358371
return Tensor([], Addressing.SCALAR, data = op)
359372
elif self.is_scalar(op):
360-
return self._cache[op.name()]
373+
return self._cache[f'{self._prefix}{op.name()}']
361374
else:
362-
return self._cache[op.name]
375+
return self._cache[f'{self._prefix}{op.name}']
363376

364377
def switch_region(self, barrier):
365378
if barrier:
366379
self._descr_list += [GridBarrierDescr()]
367380
else:
368381
self._descr_list += [GridFenceDescr()]
369382

383+
def set_region_name(self, name):
384+
self._prefix = f"{name}."
385+
self._descr_list += [RegionDescription(name)]
386+
370387
class TensorForgeWriter:
371388
def __init__(self, tensorforge_generator, headers):
372389
self._headers = list(headers) + list(tensorforge_generator.get_helper_headers())
@@ -410,6 +427,9 @@ def region_switch(self, barrier):
410427
self.generator.switch_region(barrier)
411428
return 0
412429

430+
def set_region_name(self, name):
431+
self.generator.set_region_name(name)
432+
413433
def add_operation(self, description):
414434
return self.generator.add_operation_new(description)
415435

tensorforge/generators/descriptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,16 @@ def __str__(self):
189189

190190
def trueBarrier(self):
191191
return True
192+
193+
class RegionDescription(OperationDescription):
194+
def __init__(self, name):
195+
self.name = name
196+
197+
def matrix_list(self):
198+
return []
199+
200+
def get_num_threads(self, ctx):
201+
return 32, 32
202+
203+
def __str__(self):
204+
return f'region "{self.name}"'

0 commit comments

Comments (0)