#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec512/vec512_float8.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/EmbeddingBag.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Unroll.h>
#include <torch/all.h>

namespace torchao {

namespace {

#if defined(CPU_CAPABILITY_AVX512)
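// Load 16 fp8 e4m3 values and convert them to 16 fp32 lanes in a __m512.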
static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
  __m512 o;
  __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
  at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
  return o;
}
#endif

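// For each bag b in [bs_begin, bs_end), sum the fp8 rows selected by
// indices[offsets[b]..offsets[b+1]), apply the scale, and write fp32 results.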
template <typename index_t>
inline void _scaled_embedding_bag_krnl(
    const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
    const int64_t emb_dim, const index_t last_offset, const index_t *indices,
    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
    float *result, const int64_t num_batch) {
#if defined(CPU_CAPABILITY_AVX512)
  // Fast path: process emb_dim in blocks of 128 floats (8 x 16 fp32 lanes).
  if (emb_dim % 128 == 0) {
    constexpr int64_t block_dim = 128;
    const int64_t num_blocks = emb_dim / block_dim;
    __m512 scale_v = _mm512_set1_ps(scale);
    for (int64_t b = bs_begin; b < bs_end; ++b) {
      __m512 x0, x1, x2, x3, x4, x5, x6, x7;
      int64_t start_idx = offsets[b];
      int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                            ? last_offset
                            : offsets[b + 1];
      for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
        // Load the first row of the bag into eight 16-lane accumulators.
        int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
        float *block_result = result + block_dim * block_id;
        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
        for (int64_t j = start_idx + 1; j < end_idx; ++j) {
          // Accumulate the remaining rows of the bag.
          idx = indices[j] * emb_dim + block_dim * block_id;
          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
        }
        // Apply the scale and store the accumulated block.
        x0 = _mm512_mul_ps(x0, scale_v);
        x1 = _mm512_mul_ps(x1, scale_v);
        x2 = _mm512_mul_ps(x2, scale_v);
        x3 = _mm512_mul_ps(x3, scale_v);
        x4 = _mm512_mul_ps(x4, scale_v);
        x5 = _mm512_mul_ps(x5, scale_v);
        x6 = _mm512_mul_ps(x6, scale_v);
        x7 = _mm512_mul_ps(x7, scale_v);
        _mm512_store_ps(block_result, x0);
        _mm512_store_ps(block_result + 16, x1);
        _mm512_store_ps(block_result + 32, x2);
        _mm512_store_ps(block_result + 48, x3);
        _mm512_store_ps(block_result + 64, x4);
        _mm512_store_ps(block_result + 80, x5);
        _mm512_store_ps(block_result + 96, x6);
        _mm512_store_ps(block_result + 112, x7);
      }
      result += num_emb * emb_dim;
    }
    return;
  }
#endif
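  // Scalar fallback for non-AVX512 builds or when emb_dim is not a multiple of 128.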
  for (int64_t b = bs_begin; b < bs_end; ++b) {
    int64_t start_idx = offsets[b];
    int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                          ? last_offset
                          : offsets[b + 1];
    for (int64_t d = 0; d < emb_dim; d++) {
      int64_t idx = indices[start_idx] * emb_dim;
      float value = float(weight[idx + d]);
      for (int64_t j = start_idx + 1; j < end_idx; ++j) {
        idx = indices[j] * emb_dim;
        value += float(weight[idx + d]);
      }
      value = value * scale;
      result[d] = value;
    }
    result += num_emb * emb_dim;
  }
}

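// Partition the batch into blocks of 512 bags and run the kernel on each block
// in parallel via OpenMP. The weight scale is folded with the output scale so
// the kernel applies a single multiplier.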
template <typename index_t, typename data_t>
void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
                           index_t *offsets_ptr, int64_t num_batch,
                           int64_t emb_dim, index_t last_offset, double w_scale,
                           double o_scale) {
  constexpr int64_t b_block = 512;
  const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
  w_scale /= o_scale;
  const int64_t num_emb = 1;
#pragma omp parallel for collapse(2)
  for (int64_t b = 0; b < n_b_blocks; ++b) {
    for (int64_t n = 0; n < num_emb; ++n) {
      const int64_t bs_begin = b * b_block;
      const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // Pass num_batch and last_offset so the kernel can handle offsets that
      // do not cover the last batch.
      _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
                                 last_offset, indices_ptr, offsets_ptr, w_ptr,
                                 w_scale, r, num_batch);
    }
  }
}

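// Entry point registered as torchao::_scaled_embedding_bag: validates the
// arguments, dispatches on the index type, and returns an fp32
// [batch_size, emb_dim] output.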
at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
                                      const at::Tensor &indices,
                                      const at::Tensor &offsets,
                                      const at::Tensor &w_scales,
                                      double o_scale, const int64_t mode,
                                      bool include_last_offset) {
  // Only support include_last_offset == true and
  // mode == at::native::EmbeddingBagMode::SUM.
  // TODO: Support more cases.
  TORCH_CHECK(include_last_offset,
              "_scaled_embedding_bag: only support include_last_offset");
  TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
              "_scaled_embedding_bag: only support sum mode");
  int64_t batch_size =
      include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
  int64_t emb_dim = qweight.size(1);

  auto index_type = indices.scalar_type();
  auto qtype = qweight.scalar_type();
  float w_scale = w_scales.data_ptr<float>()[0];

  TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
              "_scaled_embedding_bag: only accept contiguous input");
  TORCH_CHECK(
      offsets.scalar_type() == index_type,
      "_scaled_embedding_bag: index and offset must be of the same type");
  TORCH_CHECK(qweight.is_contiguous(),
              "_scaled_embedding_bag: only accept contiguous weight");
  TORCH_CHECK(qweight.dim() == 2,
              "_scaled_embedding_bag: only accept weight with dim == 2");
  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
              "_scaled_embedding_bag: only support e4m3fn weight");
  // With include_last_offset == true, indices.numel() serves as the final
  // offset for the last bag.
  int64_t last_offset = indices.numel();

  at::Tensor output =
      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
    index_t *indices_ptr = indices.data_ptr<index_t>();
    index_t *offsets_ptr = offsets.data_ptr<index_t>();
    float *output_ptr = output.data_ptr<float>();
    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
        last_offset, w_scale, o_scale);
  });
  return output;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  m.impl("torchao::_scaled_embedding_bag", &_scaled_embedding_bag_impl);
}

} // namespace torchao