Skip to content

Commit 5236007

Browse files
kurtamohlerpytorchmergebot
authored andcommitted
[MPS] Add embedding_bag forward pass (pytorch#163012)
Part of pytorch#162270 Pull Request resolved: pytorch#163012 Approved by: https://github.com/kulinseth, https://github.com/malfet
1 parent 167ad09 commit 5236007

File tree

8 files changed

+423
-5
lines changed

8 files changed

+423
-5
lines changed

aten/src/ATen/native/EmbeddingBag.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#pragma once
12
#include <ATen/core/Tensor.h>
23
#include <ATen/Config.h>
34
#include <cstdint>
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
#include <c10/metal/common.h>
3+
4+
#ifdef __METAL__
5+
enum class EmbeddingBagMode { SUM = 0, MEAN, MAX };
6+
#else
7+
#include <ATen/native/EmbeddingBag.h>
8+
using at::native::EmbeddingBagMode;
9+
#endif
10+
11+
template <typename idx_type_t = uint32_t>
12+
struct EmbeddingBagParams {
13+
::c10::metal::array<idx_type_t, 2> weight_strides;
14+
::c10::metal::array<idx_type_t, 2> output_strides;
15+
::c10::metal::array<idx_type_t, 2> max_indices_strides;
16+
17+
idx_type_t per_sample_weights_strides;
18+
19+
idx_type_t num_indices;
20+
idx_type_t num_bags;
21+
idx_type_t feature_size;
22+
23+
EmbeddingBagMode mode;
24+
int64_t padding_idx;
25+
};
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
#include <ATen/native/mps/kernels/EmbeddingBag.h>
2+
#include <c10/metal/utils.h>
3+
#include <metal_array>
4+
#include <metal_stdlib>
5+
6+
using namespace metal;
7+
using namespace c10::metal;
8+
9+
template <EmbeddingBagMode M, typename T>
10+
struct ReductionOpInit {
11+
inline opmath_t<T> operator()() {
12+
return 0;
13+
}
14+
};
15+
16+
template <typename T>
17+
struct ReductionOpInit<EmbeddingBagMode::MAX, T> {
18+
inline opmath_t<T> operator()() {
19+
return static_cast<opmath_t<T>>(-INFINITY);
20+
}
21+
};
22+
23+
template <EmbeddingBagMode M, typename T>
24+
struct ReductionOp {
25+
inline opmath_t<T> operator()(
26+
T weight_val,
27+
opmath_t<T> out_val,
28+
uint32_t per_sample_weights_index,
29+
constant T* per_sample_weights,
30+
uint32_t per_sample_weights_strides);
31+
};
32+
33+
template <typename T>
34+
struct ReductionOp<EmbeddingBagMode::SUM, T> {
35+
inline opmath_t<T> operator()(
36+
T weight_val,
37+
opmath_t<T> out_val,
38+
uint32_t per_sample_weights_index,
39+
constant T* per_sample_weights,
40+
uint32_t per_sample_weights_strides) {
41+
if (per_sample_weights_strides) {
42+
T per_sample_weight = per_sample_weights
43+
[per_sample_weights_strides * per_sample_weights_index];
44+
return static_cast<opmath_t<T>>(per_sample_weight) *
45+
static_cast<opmath_t<T>>(weight_val) +
46+
out_val;
47+
} else {
48+
return static_cast<opmath_t<T>>(weight_val) + out_val;
49+
}
50+
}
51+
};
52+
53+
template <typename T>
54+
struct ReductionOp<EmbeddingBagMode::MEAN, T> {
55+
inline opmath_t<T> operator()(
56+
T weight_val,
57+
opmath_t<T> out_val,
58+
uint32_t,
59+
constant T*,
60+
uint32_t) {
61+
return static_cast<opmath_t<T>>(weight_val) + out_val;
62+
}
63+
};
64+
65+
template <typename T>
66+
struct ReductionOp<EmbeddingBagMode::MAX, T> {
67+
inline opmath_t<T> operator()(
68+
T weight_val,
69+
opmath_t<T> out_val,
70+
uint32_t,
71+
constant T*,
72+
uint32_t) {
73+
return max(static_cast<opmath_t<T>>(weight_val), out_val);
74+
}
75+
};
76+
77+
template <EmbeddingBagMode M, typename T>
78+
struct ReductionOpFinal {
79+
inline T operator()(opmath_t<T> val, uint32_t) {
80+
return static_cast<T>(val);
81+
}
82+
};
83+
84+
template <typename T>
85+
struct ReductionOpFinal<EmbeddingBagMode::MEAN, T> {
86+
inline T operator()(opmath_t<T> val, uint32_t count) {
87+
auto out = val / count;
88+
return static_cast<T>((count == 0) ? 0 : out);
89+
}
90+
};
91+
92+
template <typename T>
93+
struct ReductionOpFinal<EmbeddingBagMode::MAX, T> {
94+
inline T operator()(opmath_t<T> val, uint32_t count) {
95+
return static_cast<T>((count == 0) ? 0 : val);
96+
}
97+
};
98+
99+
template <EmbeddingBagMode M, typename T, typename I>
100+
void embedding_bag_impl(
101+
constant T* weight,
102+
constant I* indices,
103+
constant I* offsets,
104+
constant T* per_sample_weights,
105+
device T* output,
106+
device I* offset2bag,
107+
device I* bag_size,
108+
device I* max_indices,
109+
constant EmbeddingBagParams<uint32_t>& params,
110+
uint tid) {
111+
auto num_indices = params.num_indices;
112+
auto num_bags = params.num_bags;
113+
auto feature_size = params.feature_size;
114+
auto padding_idx = params.padding_idx;
115+
auto per_sample_weights_strides = params.per_sample_weights_strides;
116+
constant auto& output_strides = params.output_strides;
117+
constant auto& weight_strides = params.weight_strides;
118+
constant auto& max_indices_strides = params.max_indices_strides;
119+
120+
auto bag_idx = tid / feature_size;
121+
auto feature_idx = tid % feature_size;
122+
123+
output += bag_idx * output_strides[0] + feature_idx * output_strides[1];
124+
125+
uint32_t offsets_end = min(bag_idx + 1, num_bags - 1);
126+
bool is_last_bag = bag_idx + 1 == num_bags;
127+
uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]);
128+
uint32_t indices_end = is_last_bag * (num_indices) +
129+
(!is_last_bag) * (static_cast<uint32_t>(offsets[offsets_end]));
130+
131+
auto out_val = ReductionOpInit<M, T>()();
132+
133+
uint32_t bag_size_ = 0;
134+
135+
for (uint32_t indices_idx = indices_start; indices_idx < indices_end;
136+
indices_idx++) {
137+
I weight_idx = indices[indices_idx];
138+
bool pad = (weight_idx == padding_idx);
139+
T weight_val = weight
140+
[static_cast<uint32_t>(weight_idx) * weight_strides[0] +
141+
feature_idx * weight_strides[1]];
142+
143+
bag_size_ += static_cast<uint32_t>(!pad);
144+
145+
auto tmp_val = ReductionOp<M, T>()(
146+
weight_val,
147+
out_val,
148+
indices_idx,
149+
per_sample_weights,
150+
per_sample_weights_strides);
151+
152+
out_val = pad ? out_val : tmp_val;
153+
}
154+
155+
*output = ReductionOpFinal<M, T>()(out_val, bag_size_);
156+
}
157+
158+
#define DISPATCH_IMPL(MODE) \
159+
return embedding_bag_impl<MODE>( \
160+
weight, \
161+
indices, \
162+
offsets, \
163+
per_sample_weights, \
164+
output, \
165+
offset2bag, \
166+
bag_size, \
167+
max_indices, \
168+
params, \
169+
tid)
170+
171+
template <typename T, typename I>
172+
kernel void embedding_bag(
173+
constant T* weight [[buffer(0)]],
174+
constant I* indices [[buffer(1)]],
175+
constant I* offsets [[buffer(2)]],
176+
constant T* per_sample_weights [[buffer(3)]],
177+
device T* output [[buffer(4)]],
178+
device I* offset2bag [[buffer(5)]],
179+
device I* bag_size [[buffer(6)]],
180+
device I* max_indices [[buffer(7)]],
181+
constant EmbeddingBagParams<uint32_t>& params [[buffer(8)]],
182+
uint tid [[thread_position_in_grid]]) {
183+
switch (params.mode) {
184+
case EmbeddingBagMode::SUM:
185+
DISPATCH_IMPL(EmbeddingBagMode::SUM);
186+
case EmbeddingBagMode::MEAN:
187+
DISPATCH_IMPL(EmbeddingBagMode::MEAN);
188+
case EmbeddingBagMode::MAX:
189+
DISPATCH_IMPL(EmbeddingBagMode::MAX);
190+
}
191+
}
192+
193+
#define REGISTER_EMBEDDING_BAG_OP(T, I) \
194+
template [[host_name("embedding_bag_" #T "_" #I)]] \
195+
kernel void embedding_bag<T, I>( \
196+
constant T * weight [[buffer(0)]], \
197+
constant I * indices [[buffer(1)]], \
198+
constant I * offsets [[buffer(2)]], \
199+
constant T * per_sample_weights [[buffer(3)]], \
200+
device T * output [[buffer(4)]], \
201+
device I * offset2bag [[buffer(5)]], \
202+
device I * bag_size [[buffer(6)]], \
203+
device I * max_indices [[buffer(7)]], \
204+
constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]], \
205+
uint tid [[thread_position_in_grid]]);
206+
207+
REGISTER_EMBEDDING_BAG_OP(float, int);
208+
REGISTER_EMBEDDING_BAG_OP(float, long);
209+
REGISTER_EMBEDDING_BAG_OP(half, int);
210+
REGISTER_EMBEDDING_BAG_OP(half, long);
211+
REGISTER_EMBEDDING_BAG_OP(bfloat, int);
212+
REGISTER_EMBEDDING_BAG_OP(bfloat, long);

0 commit comments

Comments
 (0)