Skip to content

Commit bc4dcdc

Browse files
committed
Begin adding SIMD support
1 parent ae43f64 commit bc4dcdc

File tree

8 files changed

+106
-51
lines changed

8 files changed

+106
-51
lines changed

tensorforge/backend/instructions/compute/primitives/amd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def mfma_emu_int8(writer: Writer, C, B, A, c, a, b):
374374
for x in const:
375375
constM *= x
376376
constI = [pow(constM // const[i], -1, const[i]) for i in range(len(const))]
377-
const2 = [(constM // const[i]) * constI[i] for i in range(len(const))]
377+
const2 = [float((constM // const[i]) * constI[i]) for i in range(len(const))]
378378
acc = len(const)
379379

380380
Aa = writer.varalloc()

tensorforge/backend/instructions/compute/primitives/intel.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11

2-
def dpas(C, B, A, rc, sd):
2+
def dpas(writer, C, B, A, rc, sd):
33
# cf. https://github.com/intel/intel-graphics-compiler/blob/master/documentation/visa/instructions/DPAS.md
44
# sd == depth == k * elemsIn32Bit
5-
# rc == m
6-
writer(f'asm("DPAS.tf32.tf32.{sd}.{rc} (16) %[D], %[C], %[B], %[A]" : [D]"=f"({C}) : [C]"f"({C}), [B]"d"({B}), [A]"d"({A}) :);')
5+
# rc == m [1,2,4,8]
6+
writer(f'tensorforge::intel_esimd::simd<tensorforge::TF32, 32> {A};')
7+
writer(f'tensorforge::intel_esimd::simd<tensorforge::TF32, 32> {B};')
8+
writer(f'tensorforge::intel_esimd::simd<float, 32> {C};')
9+
writer(f'{C} = tensorforge::intel_xmx::dpas<{sd}, {rc}, float>({C}, {B}, {A});')
710

8-
def matmul(writer, C, A, B, M, N, K, kx, threads, dtype, sparse, ctx):
11+
def fmadpp(writer, C, B, A, size, offset, lane):
12+
writer(f'{C}.select<{size}, 1>({offset}) += {A}[{lane}] * {B}.select<{size}, 1>({offset});')
913

14+
def load(writer, C):
15+
writer(f'{C}')
1016

17+
def matmul(writer, C, A, B, M, N, K, kx, threads, dtype, sparse, ctx):
1118
rc = 8
1219
sd = 8
13-
14-
dpas(C, A, B, rc, sd)
15-
16-
# TODO

tensorforge/backend/instructions/compute/primitives/nvidia.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,14 @@ def asmcall(self, writer, D, A, B, C):
222222
: {arggrp2(A, f"{typeidx}")}, {arggrp2(B, f"{typeidx}")}, {arggrp2(C, f"{typeid}")}
223223
);""")
224224

225+
def epilogue(self):
226+
pass
227+
225228
def generate(self, writer, context, A, B, C):
226-
Cstr = ','.join(f'{c}' for c in C)
227229
with writer.Scope():
228230
if self.mode == MMAMode.I8:
229-
raise NotImplementedError()
231+
232+
pass
230233
if self.mode == MMAMode.TF32:
231234
Atf32 = tfconvert(writer, A)
232235
Btf32 = tfconvert(writer, B)
@@ -297,11 +300,15 @@ def generate(self, writer, context, A, B, C):
297300
]
298301

299302
INSTRS = [
300-
MMAInstr(16,8,4,1,Datatype.F32,'mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32', MMAMode.TF32),
301-
MMAInstr(16,8,8,1,Datatype.F32,'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32', MMAMode.TF32),
302-
MMAInstr(8,8,4,1,Datatype.F64,'mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64', MMAMode.DIRECT),
303-
MMAInstr(16,8,4,1,Datatype.F64,'mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64', MMAMode.DIRECT),
304-
MMAInstr(16,8,8,1,Datatype.F64,'mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64', MMAMode.DIRECT),
303+
MMAInstr(16,8,4,1,Datatype.F32,'mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32', MMAMode.TF32), # SM_80
304+
MMAInstr(16,8,8,1,Datatype.F32,'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32', MMAMode.TF32), # SM_80
305+
MMAInstr(8,8,4,1,Datatype.F64,'mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64', MMAMode.DIRECT), # SM_80
306+
MMAInstr(16,8,4,1,Datatype.F64,'mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64', MMAMode.DIRECT), # SM_90
307+
MMAInstr(16,8,8,1,Datatype.F64,'mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64', MMAMode.DIRECT), # SM_90
308+
MMAInstr(16,8,16,1,Datatype.F64,'mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64', MMAMode.DIRECT), # SM_90
309+
MMAInstr(8,8,16,1,Datatype.F64,'mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32', MMAMode.I8), # SM_75
310+
MMAInstr(16,8,16,1,Datatype.F64,'mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32', MMAMode.I8), # SM_80
311+
MMAInstr(16,8,32,1,Datatype.F64,'mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32', MMAMode.I8), # SM_80
305312
]
306313

307314
def matmul(writer, C, A, B, M, N, K, kx, threads, dtype, sparse, ctx, shmptr, shmsize):

tensorforge/backend/instructions/memory/load.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import math
33
from tensorforge.common.matrix.tensor import Tensor
44
from . import AbstractShrMemWrite, MemoryInstruction
5-
from tensorforge.backend.symbol import SymbolType, Symbol, DataView, LeadIndex
5+
from tensorforge.backend.symbol import Symbol, SymbolType, DataView, LeadIndex, write_loops, LeadLoop, Loop
66
from tensorforge.common.exceptions import InternalError
77
from tensorforge.backend.writer import Writer
88
from tensorforge.common.matrix.boundingbox import BoundingBox
99
from tensorforge.common.context import Context
10+
from tensorforge.backend.data_types import RegMemObject
1011
from typing import Union, List
1112

1213
# to find a number coprime to the number of shared memory banks
@@ -278,24 +279,24 @@ def __init__(self,
278279
num_threads: int):
279280
super(GlbToRegLoader, self).__init__(context)
280281

281-
if src.stype != SymbolType.Register:
282-
raise InternalError('store: operand `src` is not in reg mem')
282+
if dest.stype != SymbolType.Register:
283+
raise InternalError('store: operand `dest` is not in reg mem')
283284

284-
if not isinstance(src.obj, RegMemObject):
285-
raise InternalError(f'store: operand `src` is registers, instead: {type(src.obj)}')
285+
if not isinstance(dest.obj, RegMemObject):
286+
raise InternalError(f'store: operand `dest` is registers, instead: {type(dest.obj)}')
286287

287-
if dest.stype != SymbolType.Global:
288-
raise InternalError('store: operand `dest` is not in global memory.')
288+
if src.stype != SymbolType.Global:
289+
raise InternalError('store: operand `src` is not in global memory.')
289290

290-
if not isinstance(dest.obj, Tensor):
291-
raise InternalError('store: operand `dest` is not a matrix')
291+
if not isinstance(src.obj, Tensor):
292+
raise InternalError('store: operand `src` is not a matrix')
292293

293294
src.add_user(self)
294295
dest.add_user(self)
295296

296-
dest.data_view = DataView(shape=dest.obj.shape,
297+
dest.data_view = DataView(shape=src.obj.shape,
297298
permute=None,
298-
bbox=dest.obj.get_bbox())
299+
bbox=src.obj.get_bbox())
299300

300301
# if dest.data_view.get_dim_size(0) > src.data_view.get_dim_size(0):
301302
# raise InternalError('store: `src` and `dest` do not match in size aling dim `0`')
@@ -305,13 +306,12 @@ def __init__(self,
305306
self._num_threads: int = num_threads
306307
self._is_ready: bool = True
307308

308-
def gen_code(self, writer: Writer) -> None:
309+
def gen_code_inner(self, writer: Writer) -> None:
309310
writer.new_line()
310311
dest_view = self._dest.data_view
311312

312313
allow_nontemporal = len(self._src.get_user_list()) == 1
313314

314-
writer(f'// {self}')
315315
src_bbox = self._src.data_view.get_bbox()
316316

317317
loops = []
@@ -326,4 +326,4 @@ def inner(indices):
326326
write_loops(self._context, writer, loops, inner)
327327

328328
def __str__(self) -> str:
329-
return f'{self._dest.name} = store{{g>r}}({self._src.name});'
329+
return f'{self._dest.name} = load{{g>r}}({self._src.name});'

tensorforge/backend/symbol.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def write_nonlead(self):
145145
return f'{self._nonlead}'
146146

147147
def write(self, context: Context):
148-
if self._block > 1:
148+
if context.get_vm().get_lexic().simd_mode:
149+
return f'({self._nonlead} * {self._block})'
150+
elif self._block > 1:
149151
return f'(({context.get_vm().get_lexic().thread_idx_x} / {self._stride}) % {self._block}) + {self._nonlead} * {self._block}'
150152
elif self._block == 1:
151153
return f'{self._nonlead}'
@@ -443,11 +445,14 @@ def encode_values(self, pos, runIdx, writer, context: Context, variable, index:
443445
return wrote
444446

445447
def load_linear(self, writer, context: Context, variable, index):
446-
if self.stype == SymbolType.Register:
447-
access = f'{self.name}[{index // 32}]' # TODO
448+
if context.get_vm().get_lexic().simd_mode:
449+
writer(f'{context.get_vm().get_lexic().simd(self.get_fptype(), 16)} {variable}({index});')
448450
else:
449-
access = f'{self.name}[{index} + threadIdx.x]'
450-
writer(f'{self.get_fptype()} {variable} = {access};')
451+
if self.stype == SymbolType.Register:
452+
access = f'{self.name}[{index // 32}]' # TODO
453+
else:
454+
access = f'{self.name}[{index} + threadIdx.x]'
455+
writer(f'{self.get_fptype()} {variable} = {access};')
451456

452457
def load(self, writer, context: Context, variable, index: List[Union[str, int, Immediate, Variable, LeadIndex]], nontemp):
453458
if self.stype == SymbolType.Data or (not self.obj.is_dense() and not isinstance(self.obj.spp, BoundingBoxSPP)):
@@ -473,7 +478,9 @@ def load(self, writer, context: Context, variable, index: List[Union[str, int, I
473478
if self.stype == SymbolType.Register or self.stype == SymbolType.Scratch:
474479
assert len(self.lead_dims) == 1
475480
idx = index[self.lead_dims[0]]
476-
if not idx.is_thread_dependent():
481+
if isinstance(idx, (float, int, np.int32)) or not idx.is_thread_dependent():
482+
if isinstance(idx, (float, int, np.int32)):
483+
idx = Immediate(idx, Datatype.I32)
477484
# doesn't work
478485
if isinstance(idx, Variable):
479486
writevar = idx.write_nonlead()
@@ -490,7 +497,9 @@ def load(self, writer, context: Context, variable, index: List[Union[str, int, I
490497
access = pre_access
491498
else:
492499
access = pre_access
493-
if self.stype == SymbolType.Global:
500+
if context.get_vm().get_lexic().simd_mode:
501+
writer(f'{context.get_vm().get_lexic().simd(self.get_fptype(), 16)} {variable}({access});')
502+
elif self.stype == SymbolType.Global:
494503
writer(f'{self.get_fptype()} {variable};')
495504
writer(context.get_vm().get_lexic().glb_load(variable, access, nontemp))
496505
else:
@@ -502,19 +511,26 @@ def store(self, writer, context, variable, index: List[Union[str, int, Immediate
502511

503512
access = self.access(context, index)
504513

505-
if self.stype == SymbolType.Global:
506-
assign = context.get_vm().get_lexic().glb_store(access, variable, nontemp)
514+
if context.get_vm().get_lexic().simd_mode:
515+
if self.stype == SymbolType.Global:
516+
writer(f'{variable}.copy_to({access});')
517+
else:
518+
writer(f'{variable} = {access};')
507519
else:
508-
assign = f'{access} = {variable};'
509-
if self.stype == SymbolType.Register or self.stype == SymbolType.Scratch:
510-
assert len(self.lead_dims) == 1
511-
if isinstance(index[self.lead_dims[0]], LeadIndex):
512-
writer(assign)
520+
if self.stype == SymbolType.Global:
521+
assign = context.get_vm().get_lexic().glb_store(access, variable, nontemp)
513522
else:
514-
with writer.If(f'{context.get_vm().get_lexic().thread_idx_x} == {index[self.lead_dims[0]]}'):
523+
assign = f'{access} = {variable};'
524+
525+
if self.stype == SymbolType.Register or self.stype == SymbolType.Scratch:
526+
assert len(self.lead_dims) == 1
527+
if isinstance(index[self.lead_dims[0]], LeadIndex):
515528
writer(assign)
516-
else:
517-
writer(assign)
529+
else:
530+
with writer.If(f'{context.get_vm().get_lexic().thread_idx_x} == {index[self.lead_dims[0]]}'):
531+
writer(assign)
532+
else:
533+
writer(assign)
518534

519535
def add_user(self, user):
520536
self._users.append(user)

tensorforge/common/vm/lexic/lexic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, underlying_hardware):
1818
self.block_idx_x = None
1919
self.stream_type = None
2020
self.restrict_kw = None
21+
self.simd_mode = False
2122

2223
@abstractmethod
2324
def multifile(self):

tensorforge/common/vm/lexic/sycl_lexic.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(self, backend, underlying_hardware):
1818
self.stream_type = "sycl::queue"
1919
self.restrict_kw = "__restrict__"
2020

21+
self.simd_mode = self._underlying_hardware == 'intel' and self._backend == 'oneapi'
22+
2123
def multifile(self):
2224
return False
2325

@@ -45,7 +47,10 @@ def kernel_definition(self, file, kernel_bounds, base_name, params, precision=No
4547
localmem = None
4648

4749
if self._underlying_hardware == 'intel' and self._backend == 'oneapi':
48-
add_items = '[[intel::reqd_sub_group_size(16)]] [[intel::kernel_args_restrict]]'
50+
if self.simd_mode:
51+
add_items = '[[intel::sycl_explicit_simd]] [[intel::kernel_args_restrict]]'
52+
else:
53+
add_items = '[[intel::reqd_sub_group_size(16)]] [[intel::kernel_args_restrict]]'
4954
else:
5055
add_items = ''
5156

@@ -62,7 +67,10 @@ def sync_block(self):
6267
return "item.barrier()"
6368

6469
def sync_simd(self):
65-
return "item.barrier()" # TODO make better
70+
if self.simd_mode:
71+
return ""
72+
else:
73+
return "item.barrier()" # TODO make better
6674

6775
def sync_grid(self):
6876
raise NotImplementedError() # TODO
@@ -74,8 +82,11 @@ def get_sub_group_id(self, sub_group_size):
7482
def active_sub_group_mask(self):
7583
return f'item.get_sub_group()'
7684

77-
def broadcast(self, variable, lane, block=None, subblock=None):
78-
return f'group_broadcast(-1, {variable}, {lane})'
85+
def broadcast(self, variable, lane, block=None, subblock=1):
86+
if self.simd_mode:
87+
return f'{variable}.select<{block}, {subblock}>({lane})'
88+
else:
89+
return f'group_broadcast(-1, {variable}, {lane})'
7990

8091
def kernel_range_object(self, name, values):
8192
return f"sycl::range<3> {name} ({values})"
@@ -96,6 +107,9 @@ def get_headers(self):
96107
def get_fptype(self, fptype, length=1):
97108
return f'sycl::vec<{fptype}, {length}>'
98109

110+
def get_simd(self, fptype, size):
111+
return f'tensorforge::intel_esimd::simd<{fptype}, {size}>'
112+
99113
def get_operation(self, op: Operation, fptype, value1, value2):
100114
if op == Operation.COPY:
101115
return value1
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once

// Intel ESIMD / XMX support header for tensorforge's SIMD code paths.
// Provides short namespace aliases and the TF32 element type used by the
// generated DPAS (matrix-multiply) kernels.

#include <sycl/ext/intel/esimd.hpp>
#include <sycl/ext/intel/experimental/esimd/tfloat32.hpp>
#include <sycl/sycl.hpp>

#include "base.h"

namespace tensorforge {
// Shorthand for the Intel explicit-SIMD (ESIMD) extension namespace.
namespace intel_esimd = sycl::ext::intel::esimd;
// XMX matrix-engine helpers (dpas, etc.). The original line aliased the
// undeclared name `iesimd`, which does not compile; the intended target is
// the `xmx` sub-namespace of the ESIMD extension declared above.
namespace intel_xmx = sycl::ext::intel::esimd::xmx;

// TF32 (tensor-float32) element type consumed by the XMX dpas instruction.
using TF32 = sycl::ext::intel::experimental::esimd::tfloat32;
} // namespace tensorforge

0 commit comments

Comments
 (0)