Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/asmjit
Submodule asmjit updated 211 files
64 changes: 32 additions & 32 deletions src/EmbeddingSpMDM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,28 +309,28 @@ GenEmbeddingSpMDMLookup<
frame.init(func);

if constexpr (instSet == inst_set_t::avx2) {
frame.setDirtyRegs(
frame.set_dirty_regs(
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
asmjit::Support::bit_mask<int>(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15));
} else {
frame.setDirtyRegs(
frame.set_dirty_regs(
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
asmjit::Support::bit_mask<int>(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bit_mask<int>(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bit_mask<int>(24, 25, 26, 27, 28, 29, 30, 31));
}

frame.setDirtyRegs(
frame.set_dirty_regs(
asmjit::RegGroup::kGp,
reg_id == 15
? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)
: asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
? asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15)
: asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14));

asmjit::FuncArgsAssignment args(&func);
if constexpr (ROWWISE_SPARSE) {
args.assignAll(
args.assign_all(
output_size,
index_size,
data_size,
Expand All @@ -342,7 +342,7 @@ GenEmbeddingSpMDMLookup<
compressed_indices_table,
scratchReg1_);
} else {
args.assignAll(
args.assign_all(
output_size,
index_size,
data_size,
Expand All @@ -354,11 +354,11 @@ GenEmbeddingSpMDMLookup<
scratchReg1_);
}

args.updateFuncFrame(frame);
args.update_func_frame(frame);
frame.finalize();

a->emitProlog(frame);
a->emitArgsAssignment(frame, args);
a->emit_prolog(frame);
a->emit_args_assignment(frame, args);

constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
Expand Down Expand Up @@ -451,19 +451,19 @@ GenEmbeddingSpMDMLookup<
a->lea(
index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2));

asmjit::Label exit = a->newLabel();
asmjit::Label error = a->newLabel();
asmjit::Label LoopRangeIndexBegin = a->newLabel();
asmjit::Label LoopRangeIndexEnd = a->newLabel();
asmjit::Label exit = a->new_label();
asmjit::Label error = a->new_label();
asmjit::Label LoopRangeIndexBegin = a->new_label();
asmjit::Label LoopRangeIndexEnd = a->new_label();

// rangeIndex loop begins (iterate output_size times)
a->bind(LoopRangeIndexBegin);
a->dec(output_size);
a->jl(LoopRangeIndexEnd);

if (normalize_by_lengths) {
asmjit::Label IfLengthsBegin = a->newLabel();
asmjit::Label IfLengthsEnd = a->newLabel();
asmjit::Label IfLengthsBegin = a->new_label();
asmjit::Label IfLengthsEnd = a->new_label();
a->bind(IfLengthsBegin);
if (use_offsets) {
a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
Expand Down Expand Up @@ -520,9 +520,9 @@ GenEmbeddingSpMDMLookup<
a->cmp(scratchReg1_, index_size);
a->jg(error);

asmjit::Label LoopDataIndexBegin = a->newLabel();
asmjit::Label LoopDataIndexEnd = a->newLabel();
asmjit::Label ValidIndexLabel = a->newLabel();
asmjit::Label LoopDataIndexBegin = a->new_label();
asmjit::Label LoopDataIndexEnd = a->new_label();
asmjit::Label ValidIndexLabel = a->new_label();

// dataIndex loop begins (iterate lengths_R_ times)
a->bind(LoopDataIndexBegin);
Expand Down Expand Up @@ -569,8 +569,8 @@ GenEmbeddingSpMDMLookup<
int fused_block_size = input_stride * sizeof(inType);

if (pref_dist) {
asmjit::Label pref_dist_reset_start = a->newLabel();
asmjit::Label pref_dist_reset_end = a->newLabel();
asmjit::Label pref_dist_reset_start = a->new_label();
asmjit::Label pref_dist_reset_end = a->new_label();
// out of bound handling for prefetch
a->lea(
scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType)));
Expand Down Expand Up @@ -601,8 +601,8 @@ GenEmbeddingSpMDMLookup<
a->bind(pref_dist_reset_end);
if constexpr (ROWWISE_SPARSE) {
asmjit::Label rowwise_sparse_pref_corner_case_begin =
a->newLabel();
asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel();
a->new_label();
asmjit::Label rowwise_sparse_pref_corner_case_end = a->new_label();
a->cmp(scratchReg2_, data_size);
a->jae(rowwise_sparse_pref_corner_case_begin);

Expand Down Expand Up @@ -934,7 +934,7 @@ GenEmbeddingSpMDMLookup<
a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t)));
}

a->emitEpilog(frame);
a->emit_epilog(frame);

// jit_fused8bitembedding_kernel fn;
typename ReturnFunctionSignature<
Expand All @@ -943,13 +943,13 @@ GenEmbeddingSpMDMLookup<
offsetType,
outType,
ROWWISE_SPARSE>::jit_embedding_kernel fn;
asmjit::Error err = 0;
asmjit::Error err = asmjit::Error::kOk;
{
std::unique_lock<std::mutex> lock(rtMutex_);
err = runtime().add(&fn, &code);
}

if (err) {
if (err != asmjit::Error::kOk) {
std::cout << "Error: in fn add" << '\n';
return nullptr;
}
Expand Down
58 changes: 29 additions & 29 deletions src/EmbeddingSpMDMNBit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,22 +283,22 @@ GenEmbeddingSpMDMNBitLookup<
asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
frame.set_dirty_regs(
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
asmjit::Support::bit_mask<int>(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bit_mask<int>(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bit_mask<int>(24, 25, 26, 27, 28, 29, 30, 31));

frame.setDirtyRegs(
frame.set_dirty_regs(
asmjit::RegGroup::kGp,
reg_id == 15
? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)
: asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
? asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15)
: asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14));

asmjit::FuncArgsAssignment args(&func);
if constexpr (ROWWISE_SPARSE) {
args.assignAll(
args.assign_all(
output_size,
index_size,
data_size,
Expand All @@ -310,7 +310,7 @@ GenEmbeddingSpMDMNBitLookup<
compressed_indices_table,
scratchReg1_);
} else {
args.assignAll(
args.assign_all(
output_size,
index_size,
data_size,
Expand All @@ -322,11 +322,11 @@ GenEmbeddingSpMDMNBitLookup<
scratchReg1_);
}

args.updateFuncFrame(frame);
args.update_func_frame(frame);
frame.finalize();

a->emitProlog(frame);
a->emitArgsAssignment(frame, args);
a->emit_prolog(frame);
a->emit_args_assignment(frame, args);

constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
Expand Down Expand Up @@ -480,19 +480,19 @@ GenEmbeddingSpMDMNBitLookup<
a->lea(
index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2));

asmjit::Label exit = a->newLabel();
asmjit::Label error = a->newLabel();
asmjit::Label LoopRangeIndexBegin = a->newLabel();
asmjit::Label LoopRangeIndexEnd = a->newLabel();
asmjit::Label exit = a->new_label();
asmjit::Label error = a->new_label();
asmjit::Label LoopRangeIndexBegin = a->new_label();
asmjit::Label LoopRangeIndexEnd = a->new_label();

// rangeIndex loop begins (iterate output_size times)
a->bind(LoopRangeIndexBegin);
a->dec(output_size);
a->jl(LoopRangeIndexEnd);

if (normalize_by_lengths) {
asmjit::Label IfLengthsBegin = a->newLabel();
asmjit::Label IfLengthsEnd = a->newLabel();
asmjit::Label IfLengthsBegin = a->new_label();
asmjit::Label IfLengthsEnd = a->new_label();
a->bind(IfLengthsBegin);
if (use_offsets) {
a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
Expand Down Expand Up @@ -548,9 +548,9 @@ GenEmbeddingSpMDMNBitLookup<
a->cmp(scratchReg1_, index_size);
a->jg(error);

asmjit::Label LoopDataIndexBegin = a->newLabel();
asmjit::Label LoopDataIndexEnd = a->newLabel();
asmjit::Label ValidIndexLabel = a->newLabel();
asmjit::Label LoopDataIndexBegin = a->new_label();
asmjit::Label LoopDataIndexEnd = a->new_label();
asmjit::Label ValidIndexLabel = a->new_label();

// dataIndex loop begins (iterate lengths_R_ times)
a->bind(LoopDataIndexBegin);
Expand Down Expand Up @@ -597,8 +597,8 @@ GenEmbeddingSpMDMNBitLookup<
int num_elem_per_byte = 8 / bit_rate;
int fused_block_size = input_stride;
if (pref_dist) {
asmjit::Label pref_dist_reset_start = a->newLabel();
asmjit::Label pref_dist_reset_end = a->newLabel();
asmjit::Label pref_dist_reset_start = a->new_label();
asmjit::Label pref_dist_reset_end = a->new_label();
// out of bound handling for prefetch
a->lea(
scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType)));
Expand Down Expand Up @@ -629,8 +629,8 @@ GenEmbeddingSpMDMNBitLookup<
a->bind(pref_dist_reset_end);
if constexpr (ROWWISE_SPARSE) {
asmjit::Label rowwise_sparse_pref_corner_case_begin =
a->newLabel();
asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel();
a->new_label();
asmjit::Label rowwise_sparse_pref_corner_case_end = a->new_label();
a->cmp(scratchReg2_, data_size);
a->jae(rowwise_sparse_pref_corner_case_begin);

Expand Down Expand Up @@ -941,20 +941,20 @@ GenEmbeddingSpMDMNBitLookup<
a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t)));
}

a->emitEpilog(frame);
a->emit_epilog(frame);

// jit_fused8bitembedding_kernel fn;
typename ReturnFunctionSignature<
indxType,
offsetType,
outType,
ROWWISE_SPARSE>::jit_embedding_kernel fn;
asmjit::Error err = 0;
asmjit::Error err = asmjit::Error::kOk;
{
unique_lock<mutex> lock(rtMutex_);
err = runtime().add(&fn, &code);
}
if (err) {
if (err != asmjit::Error::kOk) {
cout << "Error: in fn add" << '\n';
return nullptr;
}
Expand Down
38 changes: 19 additions & 19 deletions src/FbgemmI64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,28 +185,28 @@ CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
frame.set_dirty_regs(
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
frame.setDirtyRegs(
asmjit::Support::bit_mask<int>(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bit_mask<int>(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bit_mask<int>(24, 25, 26, 27, 28, 29, 30, 31));
frame.set_dirty_regs(
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
asmjit::Support::bit_mask<int>(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
args.assign_all(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);

args.updateFuncFrame(frame);
args.update_func_frame(frame);
frame.finalize();

a->emitProlog(frame);
a->emitArgsAssignment(frame, args);
a->emit_prolog(frame);
a->emit_args_assignment(frame, args);

asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->new_label();
asmjit::Label LoopNBlocks = a->new_label();
asmjit::Label Loopk = a->new_label();

x86::Gp buffer_B_saved = a->gpz(10);
x86::Gp C_Offset = a->gpz(11);
Expand Down Expand Up @@ -308,8 +308,8 @@ CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
// generate code for remainder
if (mRegBlocksRem > 0) {
assert(false);
asmjit::Label LoopNRem = a->newLabel();
asmjit::Label LoopkRem = a->newLabel();
asmjit::Label LoopNRem = a->new_label();
asmjit::Label LoopkRem = a->new_label();
int rowRegs = mRegBlocksRem;

a->xor_(jIdx.r32(), jIdx.r32());
Expand Down Expand Up @@ -366,15 +366,15 @@ CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
a->jl(LoopNRem);
}

a->emitEpilog(frame);
a->emit_epilog(frame);

jit_micro_kernel_fp fn = nullptr;
asmjit::Error err = 0;
asmjit::Error err = asmjit::Error::kOk;
{
unique_lock<mutex> lock(rtMutex_);
err = runtime().add(&fn, &code);
}
if (err) {
if (err != asmjit::Error::kOk) {
cout << "Error: in fn add" << '\n';
return nullptr;
}
Expand Down
Loading
Loading