Commit 2a53216

[CPU][float8] Add scaled_embedding_bag kernel (#2686)
* add embeddingbag kernel
* switch to use cvtfp8e4m3_fp32
* improve code style
* rm unused buf
* mv ut to test/test_ops.py
* refine kernel
* add test case
* add more assert
* add more test case
* fix accuracy issue
* rename qembeddingbag to _scaled_embedding_bag
* improve code style
* change atol and rtol
1 parent 6e9bf26 commit 2a53216
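The commit adds a CPU implementation of a scaled EmbeddingBag for float8_e4m3fn weights, exposed as torchao::_scaled_embedding_bag. Below is a minimal usage sketch based on the operator schema and the test added in this commit; the table size, shapes, and scale values are illustrative and not part of the commit.

import torch

num_rows, emb_dim, batch, multi_hot = 1000, 128, 4, 3
fp8_weight = torch.randn(num_rows, emb_dim).to(torch.float8_e4m3fn)
weight_scale = torch.tensor([2.0])                       # per-tensor scale of the fp8 weight
indices = torch.randint(num_rows, (batch * multi_hot,))  # flattened bag contents
offsets = torch.arange(0, (batch + 1) * multi_hot, multi_hot)  # includes the last offset

out = torch.ops.torchao._scaled_embedding_bag(
    fp8_weight,
    indices,
    offsets,
    weight_scale,
    1.0,   # o_scale: the result is scaled by weight_scale / o_scale
    0,     # mode: 0 == "sum", the only mode supported by this commit
    True,  # include_last_offset must be True
)  # float32 output of shape (batch, emb_dim) on CPU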

3 files changed: +266 additions, -0 deletions


test/test_ops.py

Lines changed: 64 additions & 0 deletions
@@ -764,5 +764,69 @@ def test_swizzle_mm():
    )


EMBEDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
EMBEDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
EMBEDINGBAG_VECTOR_SIZES = [1, 128, 512]
EMBEDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]

EMBEDINGBAG_TEST_PARAMS = list(
    itertools.product(
        EMBEDINGBAG_MULTIHOT_SIZES,
        EMBEDINGBAG_BAG_SIZES,
        EMBEDINGBAG_VECTOR_SIZES,
        EMBEDINGBAG_INDEX_DTYPES,
    )
)


@pytest.mark.skipif(
    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
    reason="cpp kernels not built",
)
@pytest.mark.parametrize(
    "multi_hot, batch_size, vector_size, index_type",
    EMBEDINGBAG_TEST_PARAMS,
    ids=str,
)
def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type):
    qtype = torch.float8_e4m3fn
    dtype = torch.float32
    weight_scale = torch.tensor([2.0])
    include_last_offset = True
    mode = "sum"

    if mode == "sum":
        mode_enum = 0
    elif mode == "mean":
        mode_enum = 1
    elif mode == "max":
        mode_enum = 2
    indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
    offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(index_type)

    m = torch.nn.EmbeddingBag(
        1000,
        vector_size,
        mode=mode,
        dtype=dtype,
        include_last_offset=include_last_offset,
    )
    fp8_weight = m.weight.data.to(qtype)
    m.weight.data = fp8_weight.to(m.weight.dtype)

    with torch.no_grad():
        refe_out = m.forward(indices, offsets) * weight_scale
        test_out = torch.ops.torchao._scaled_embedding_bag(
            fp8_weight,
            indices,
            offsets,
            weight_scale,
            1.0,
            mode_enum,
            include_last_offset,
        ).to(dtype)
    torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    pytest.main(sys.argv)

New C++ source file (CPU kernel; path not shown in this view)

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec512/vec512_float8.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/EmbeddingBag.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Unroll.h>
#include <torch/all.h>

namespace torchao {

namespace {

#if defined(CPU_CAPABILITY_AVX512)
static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
  __m512 o;
  __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
  at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
  return o;
}
#endif

template <typename index_t>
inline void _scaled_embedding_bag_krnl(
    const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
    const int64_t emb_dim, const index_t last_offset, const index_t *indices,
    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
    float *result, const int64_t num_batch) {
#if defined(CPU_CAPABILITY_AVX512)
  if (emb_dim % 128 == 0) {
    constexpr int64_t block_dim = 128;
    const int64_t num_blocks = emb_dim / block_dim;
    __m512 scale_v = _mm512_set1_ps(scale);
    for (int64_t b = bs_begin; b < bs_end; ++b) {
      __m512 x0, x1, x2, x3, x4, x5, x6, x7;
      int64_t start_idx = offsets[b];
      int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                            ? last_offset
                            : offsets[b + 1];
      for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
        // load the first index of the bag
        int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
        float *block_result = result + block_dim * block_id;
        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
        for (int64_t j = start_idx + 1; j < end_idx; ++j) {
          // accumulate the remaining indices of the bag
          idx = indices[j] * emb_dim + block_dim * block_id;
          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
        }
        x0 = _mm512_mul_ps(x0, scale_v);
        x1 = _mm512_mul_ps(x1, scale_v);
        x2 = _mm512_mul_ps(x2, scale_v);
        x3 = _mm512_mul_ps(x3, scale_v);
        x4 = _mm512_mul_ps(x4, scale_v);
        x5 = _mm512_mul_ps(x5, scale_v);
        x6 = _mm512_mul_ps(x6, scale_v);
        x7 = _mm512_mul_ps(x7, scale_v);
        // store the scaled sums
        _mm512_store_ps(block_result, x0);
        _mm512_store_ps(block_result + 16, x1);
        _mm512_store_ps(block_result + 32, x2);
        _mm512_store_ps(block_result + 48, x3);
        _mm512_store_ps(block_result + 64, x4);
        _mm512_store_ps(block_result + 80, x5);
        _mm512_store_ps(block_result + 96, x6);
        _mm512_store_ps(block_result + 112, x7);
      }
      result += num_emb * emb_dim;
    }
    return;
  }
#endif
  // scalar fallback path
  for (int64_t b = bs_begin; b < bs_end; ++b) {
    int64_t start_idx = offsets[b];
    int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                          ? last_offset
                          : offsets[b + 1];
    for (int64_t d = 0; d < emb_dim; d++) {
      int64_t idx = indices[start_idx] * emb_dim;
      float value = float(weight[idx + d]);
      for (int64_t j = start_idx + 1; j < end_idx; ++j) {
        idx = indices[j] * emb_dim;
        value += float(weight[idx + d]);
      }
      value = value * scale;
      result[d] = value;
    }
    result += num_emb * emb_dim;
  }
}

template <typename index_t, typename data_t>
void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
                           index_t *offsets_ptr, int64_t num_batch,
                           int64_t emb_dim, index_t last_offset, double w_scale,
                           double o_scale) {
  constexpr int64_t b_block = 512;
  const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
  w_scale /= o_scale;
  const int64_t num_emb = 1;
#pragma omp parallel for collapse(2)
  for (int64_t b = 0; b < n_b_blocks; ++b) {
    for (int64_t n = 0; n < num_emb; ++n) {
      const int64_t bs_begin = b * b_block;
      const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // pass last_offset to handle offsets that do not include the last batch
      _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
                                 last_offset, indices_ptr, offsets_ptr, w_ptr,
                                 w_scale, r, num_batch);
    }
  }
}

at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
                                      const at::Tensor &indices,
                                      const at::Tensor &offsets,
                                      const at::Tensor &w_scales,
                                      double o_scale, const int64_t mode,
                                      bool include_last_offset) {
  // Only include_last_offset == true and mode ==
  // at::native::EmbeddingBagMode::SUM are supported for now.
  // TODO: support more cases.
  TORCH_CHECK(include_last_offset,
              "_scaled_embedding_bag: only support include_last_offset");
  TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
              "_scaled_embedding_bag: only support sum mode");
  int64_t batch_size =
      include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
  int64_t emb_dim = qweight.size(1);

  auto index_type = indices.scalar_type();
  auto qtype = qweight.scalar_type();
  float w_scale = w_scales.data_ptr<float>()[0];

  TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
              "_scaled_embedding_bag: only accept contiguous input");
  TORCH_CHECK(
      offsets.scalar_type() == index_type,
      "_scaled_embedding_bag: index and offset must be of the same type");
  TORCH_CHECK(qweight.is_contiguous(),
              "_scaled_embedding_bag: only accept contiguous weight");
  TORCH_CHECK(qweight.dim() == 2,
              "_scaled_embedding_bag: only accept weight with dim == 2");
  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
              "_scaled_embedding_bag: only support e4m3fn weight");
  // handle the last offset
  int64_t last_offset = indices.numel();

  at::Tensor output =
      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
    index_t *indices_ptr = indices.data_ptr<index_t>();
    index_t *offsets_ptr = offsets.data_ptr<index_t>();
    float *output_ptr = output.data_ptr<float>();
    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
        last_offset, w_scale, o_scale);
  });
  return output;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  m.impl("torchao::_scaled_embedding_bag", &_scaled_embedding_bag_impl);
}

} // namespace torchao
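
For reference, the computation that both the AVX512 path and the scalar fallback above implement can be written in a few lines of PyTorch. This is an explanatory sketch, not part of the commit; it assumes the include_last_offset=True layout in which offsets holds batch_size + 1 entries and that w_scale and o_scale are plain Python floats.

import torch

def scaled_embedding_bag_ref(fp8_weight, indices, offsets, w_scale, o_scale):
    # Gather the referenced rows and dequantize them to fp32.
    rows = fp8_weight.to(torch.float32)[indices]
    # Sum each bag [offsets[b], offsets[b + 1]) of gathered rows.
    bags = [rows[offsets[b]:offsets[b + 1]].sum(dim=0)
            for b in range(offsets.numel() - 1)]
    # Apply the combined scale, matching `w_scale /= o_scale` in the kernel.
    return torch.stack(bags) * (w_scale / o_scale)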

torchao/ops.py

Lines changed: 19 additions & 0 deletions
@@ -68,6 +68,9 @@
lib.define(
    "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
)
lib.define(
    "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
)


def register_custom_op(name):
@@ -1098,3 +1101,19 @@ def _(
    assert weight.dim() == 4
    N = weight.size(0) * weight.size(3) * 2
    return input.new_empty(*input.shape[:-1], N, dtype=out_dtype)


@register_custom_op("torchao::_scaled_embedding_bag")
def _(
    qweight: Tensor,
    indices: Tensor,
    offsets: Tensor,
    w_scales: Tensor,
    o_scale: float,
    mode: int,
    include_last_offset: bool,
) -> Tensor:
    # Only support include_last_offset == True
    assert include_last_offset == True
    batch_size = offsets.shape[0] - 1
    return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype)
