Skip to content

Commit 20c6797

Browse files
committed
feat: finished matmul_lt16_4n_k
1 parent 9b78662 commit 20c6797

File tree

14 files changed

+1660
-14
lines changed

14 files changed

+1660
-14
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ set(KERNEL_FILES
8484
matmul_16mRest_4n_k.cpp
8585
matmul_16mRest_4nRest_k.h
8686
matmul_16mRest_4nRest_k.cpp
87+
matmul_lt16_4n_k.h
88+
matmul_lt16_4n_k.cpp
8789
)
8890

8991
set(ARM_INSTRUCTION_FILES
@@ -127,6 +129,7 @@ set(TEST_KERNELS
127129
matmul_16m_4n_k.test.cpp
128130
matmul_16m_lt4nRest_k.test.cpp
129131
matmul_16mRest_4nRest_k.test.cpp
132+
matmul_lt16_4n_k.test.cpp
130133
)
131134

132135
set(TEST_ARM_INSTRUCTION_FILES

src/main/arm_instructions/simd_fp/ld1.h

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ namespace mini_jit
1515

1616
const uint32_t ld1ImmediateRm = 0b11111;
1717

18+
/// Element width selector for the single-structure LD1 encoders.
///
/// The low three bits of every enumerator are the value written into the `opcode`
/// field (bits 15:13); the encoders obtain it with `static_cast<uint32_t>(type) & mask3`.
enum class ld1DataTypes
{
  v8bit = 0b000,
  v16bit = 0b010,
  v32bit = 0b100,
  // 64-bit elements use the same opcode (0b100) as 32-bit elements. The additional
  // high bit only keeps this enumerator distinct from v32bit so both can be switch
  // cases; it is masked away before the opcode is emitted.
  v64bit = 0b1100,
};
1826
enum class ld1Types
1927
{
2028
t8b,
@@ -201,6 +209,152 @@ namespace mini_jit
201209
return ld1;
202210
}
203211

212+
constexpr uint32_t ld1SingleStructures(const uint32_t Vt, const ld1DataTypes type, const uint32_t index, const uint32_t Xn)
213+
{
214+
release_assert((Vt & mask5) == Vt, "Vt is only allowed to have a size of 5 bit.");
215+
release_assert((Xn & mask5) == Xn, "Xn is only allowed to have a size of 5 bit.");
216+
217+
uint32_t q = 0xff; // should change
218+
uint32_t s = 0xff; // should change
219+
uint32_t size = 0xff; // should change
220+
switch (type)
221+
{
222+
case ld1DataTypes::v8bit:
223+
release_assert(index <= 15, "index is maximum is 15.");
224+
225+
q = (index >> 3) & mask1;
226+
s = (index >> 2) & mask1;
227+
size = index & mask2;
228+
break;
229+
case ld1DataTypes::v16bit:
230+
release_assert(index <= 7, "index is maximum is 7.");
231+
232+
q = (index >> 2) & mask1;
233+
s = (index >> 1) & mask1;
234+
size = 0b0;
235+
size |= (index & mask1) << 1;
236+
break;
237+
case ld1DataTypes::v32bit:
238+
release_assert(index <= 3, "index is maximum is 3.");
239+
240+
q = (index >> 1) & mask1;
241+
s = index & mask1;
242+
size = 0b00;
243+
break;
244+
case ld1DataTypes::v64bit:
245+
release_assert(index <= 1, "index is maximum is 1.");
246+
q = index & mask1;
247+
s = 0;
248+
size = 0b01;
249+
break;
250+
default:
251+
release_assert(false, "Undefined ld1 data type found.");
252+
break;
253+
}
254+
release_assert(q != 0xff, "Q should be retrieved from a type.");
255+
release_assert(s != 0xff, "S should be retrieved from a type.");
256+
release_assert(size != 0xff, "Size should be retrieved from a type.");
257+
release_assert((q & mask1) == q, "Q is only allowed to have a size of 1 bit.");
258+
release_assert((s & mask1) == s, "S is only allowed to have a size of 1 bit.");
259+
release_assert((size & mask2) == size, "Size is only allowed to have a size of 2 bit.");
260+
261+
uint32_t ld1 = 0;
262+
ld1 |= 0b0 << 31;
263+
ld1 |= (q & mask1) << 30;
264+
ld1 |= 0b00110101000000 << 16;
265+
ld1 |= (static_cast<uint32_t>(type) & mask3) << 13; // opcode
266+
ld1 |= (s & mask1) << 12;
267+
ld1 |= (size & mask2) << 10;
268+
ld1 |= (Xn & mask5) << 5;
269+
ld1 |= (Vt & mask5) << 0;
270+
return ld1;
271+
}
272+
273+
constexpr uint32_t ld1SingleStructuresPost(const uint32_t Vt, const ld1DataTypes type, const uint32_t index, const uint32_t Xn,
274+
const uint32_t imm, const uint32_t Xm)
275+
{
276+
release_assert((Vt & mask5) == Vt, "Vt is only allowed to have a size of 5 bit.");
277+
release_assert((Xn & mask5) == Xn, "Xn is only allowed to have a size of 5 bit.");
278+
release_assert((Xm & mask5) == Xm, "Xm is only allowed to have a size of 5 bit.");
279+
280+
if (Xm == ld1ImmediateRm)
281+
{
282+
switch (type)
283+
{
284+
case ld1DataTypes::v8bit:
285+
release_assert(imm == 1, "immm is only allowed to be 1 for the v8bit type.");
286+
break;
287+
case ld1DataTypes::v16bit:
288+
release_assert(imm == 2, "immm is only allowed to be 2 for the v8bit type.");
289+
break;
290+
case ld1DataTypes::v32bit:
291+
release_assert(imm == 4, "immm is only allowed to be 4 for the v8bit type.");
292+
break;
293+
case ld1DataTypes::v64bit:
294+
release_assert(imm == 8, "immm is only allowed to be 8 for the v8bit type.");
295+
break;
296+
default:
297+
release_assert(false, "Undefined ld1 data type found.");
298+
break;
299+
}
300+
}
301+
302+
uint32_t q = 0xff; // should change
303+
uint32_t s = 0xff; // should change
304+
uint32_t size = 0xff; // should change
305+
switch (type)
306+
{
307+
case ld1DataTypes::v8bit:
308+
release_assert(index <= 15, "index is maximum is 15.");
309+
310+
q = (index >> 3) & mask1;
311+
s = (index >> 2) & mask1;
312+
size = index & mask2;
313+
break;
314+
case ld1DataTypes::v16bit:
315+
release_assert(index <= 7, "index is maximum is 7.");
316+
317+
q = (index >> 2) & mask1;
318+
s = (index >> 1) & mask1;
319+
size = 0b0;
320+
size |= (index & mask1) << 1;
321+
break;
322+
case ld1DataTypes::v32bit:
323+
release_assert(index <= 3, "index is maximum is 3.");
324+
325+
q = (index >> 1) & mask1;
326+
s = index & mask1;
327+
size = 0b00;
328+
break;
329+
case ld1DataTypes::v64bit:
330+
release_assert(index <= 1, "index is maximum is 1.");
331+
q = index & mask1;
332+
s = 0;
333+
size = 0b01;
334+
break;
335+
default:
336+
release_assert(false, "Undefined ld1 data type found.");
337+
break;
338+
}
339+
release_assert(q != 0xff, "Q should be retrieved from a type.");
340+
release_assert(s != 0xff, "S should be retrieved from a type.");
341+
release_assert(size != 0xff, "Size should be retrieved from a type.");
342+
release_assert((q & mask1) == q, "Q is only allowed to have a size of 1 bit.");
343+
release_assert((s & mask1) == s, "S is only allowed to have a size of 1 bit.");
344+
release_assert((size & mask2) == size, "Size is only allowed to have a size of 2 bit.");
345+
346+
uint32_t ld1 = 0;
347+
ld1 |= 0b0 << 31;
348+
ld1 |= (q & mask1) << 30;
349+
ld1 |= 0b001101110 << 21;
350+
ld1 |= (Xm & mask5) << 16;
351+
ld1 |= (static_cast<uint32_t>(type) & mask3) << 13; // opcode
352+
ld1 |= (s & mask1) << 12;
353+
ld1 |= (size & mask2) << 10;
354+
ld1 |= (Xn & mask5) << 5;
355+
ld1 |= (Vt & mask5) << 0;
356+
return ld1;
357+
}
204358
} // namespace internal
205359

206360
template <typename T> constexpr uint32_t ld1(const VGeneral Vt, const T, const R64Bit Xn)
@@ -331,6 +485,74 @@ namespace mini_jit
331485
return internal::ld1MultipleStructuresPost(static_cast<uint32_t>(Vt), type, static_cast<uint32_t>(Xn), 0, static_cast<uint32_t>(Xm),
332486
4);
333487
}
488+
489+
constexpr uint32_t ld1(const V8Bit bt, const uint32_t index, const R64Bit Xn)
490+
{
491+
return internal::ld1SingleStructures(static_cast<uint32_t>(bt), internal::ld1DataTypes::v8bit, index, static_cast<uint32_t>(Xn));
492+
}
493+
494+
constexpr uint32_t ld1(const V16Bit bt, const uint32_t index, const R64Bit Xn)
495+
{
496+
return internal::ld1SingleStructures(static_cast<uint32_t>(bt), internal::ld1DataTypes::v16bit, index, static_cast<uint32_t>(Xn));
497+
}
498+
499+
constexpr uint32_t ld1(const V32Bit bt, const uint32_t index, const R64Bit Xn)
500+
{
501+
return internal::ld1SingleStructures(static_cast<uint32_t>(bt), internal::ld1DataTypes::v32bit, index, static_cast<uint32_t>(Xn));
502+
}
503+
504+
constexpr uint32_t ld1(const V64Bit bt, const uint32_t index, const R64Bit Xn)
505+
{
506+
return internal::ld1SingleStructures(static_cast<uint32_t>(bt), internal::ld1DataTypes::v64bit, index, static_cast<uint32_t>(Xn));
507+
}
508+
509+
constexpr uint32_t ld1Post(const V8Bit bt, const uint32_t index, const R64Bit Xn, const uint32_t imm)
510+
{
511+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v8bit, index, static_cast<uint32_t>(Xn),
512+
imm, internal::ld1ImmediateRm);
513+
}
514+
515+
constexpr uint32_t ld1Post(const V16Bit bt, const uint32_t index, const R64Bit Xn, const uint32_t imm)
516+
{
517+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v16bit, index, static_cast<uint32_t>(Xn),
518+
imm, internal::ld1ImmediateRm);
519+
}
520+
521+
constexpr uint32_t ld1Post(const V32Bit bt, const uint32_t index, const R64Bit Xn, const uint32_t imm)
522+
{
523+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v32bit, index, static_cast<uint32_t>(Xn),
524+
imm, internal::ld1ImmediateRm);
525+
}
526+
527+
constexpr uint32_t ld1Post(const V64Bit bt, const uint32_t index, const R64Bit Xn, const uint32_t imm)
528+
{
529+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v64bit, index, static_cast<uint32_t>(Xn),
530+
imm, internal::ld1ImmediateRm);
531+
}
532+
533+
constexpr uint32_t ld1Post(const V8Bit bt, const uint32_t index, const R64Bit Xn, const R64Bit Xm)
534+
{
535+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v8bit, index, static_cast<uint32_t>(Xn),
536+
0, static_cast<uint32_t>(Xm));
537+
}
538+
539+
constexpr uint32_t ld1Post(const V16Bit bt, const uint32_t index, const R64Bit Xn, const R64Bit Xm)
540+
{
541+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v16bit, index, static_cast<uint32_t>(Xn),
542+
0, static_cast<uint32_t>(Xm));
543+
}
544+
545+
constexpr uint32_t ld1Post(const V32Bit bt, const uint32_t index, const R64Bit Xn, const R64Bit Xm)
546+
{
547+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v32bit, index, static_cast<uint32_t>(Xn),
548+
0, static_cast<uint32_t>(Xm));
549+
}
550+
551+
constexpr uint32_t ld1Post(const V64Bit bt, const uint32_t index, const R64Bit Xn, const R64Bit Xm)
552+
{
553+
return internal::ld1SingleStructuresPost(static_cast<uint32_t>(bt), internal::ld1DataTypes::v64bit, index, static_cast<uint32_t>(Xn),
554+
0, static_cast<uint32_t>(Xm));
555+
}
334556
} // namespace arm_instructions
335557
} // namespace mini_jit
336558

0 commit comments

Comments
 (0)