Skip to content

Commit bd95cf9

Browse files
manuelcandales authored and facebook-github-bot committed
ATen Compliance: masked_fill (Dtype, Shape & Broadcast)
Reviewed By: kirklandsign
Differential Revision: D48276291
fbshipit-source-id: 656f39a8303e5c5fb231f8f421f5881b4d2721e0
1 parent 848b8bd commit bd95cf9

File tree

1 file changed

+22
-69
lines changed

1 file changed

+22
-69
lines changed

kernels/portable/cpu/op_masked_fill.cpp

Lines changed: 22 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -17,87 +17,40 @@ namespace native {
1717
using Tensor = exec_aten::Tensor;
1818
using ScalarType = exec_aten::ScalarType;
1919
using Scalar = exec_aten::Scalar;
20-
namespace {
2120

22-
/**
23-
* Fills the `out` with values of `self` or `value` based on mask.
24-
*
25-
* Assumes that the tensors are contiguous, are the same shape,
26-
* input and output have the same time and mask is tensor of bools.
27-
* CTYPE should be the C type (like `float` or `int`) that matches
28-
* the dtype of the tensors.
29-
*/
30-
template <class CTYPE>
31-
void masked_fill_kernel(
32-
const Tensor& self,
33-
const Tensor& mask,
34-
const Scalar& value,
35-
Tensor& out) {
36-
ET_DCHECK(self.numel() == mask.numel() && self.numel() == out.numel());
37-
CTYPE value_v = 0;
38-
bool ok = utils::extract_scalar(value, &value_v);
39-
ET_CHECK_MSG(ok, "Invalid fill value: wrong type or out of range");
40-
const size_t n = self.numel();
41-
const auto data_self = self.const_data_ptr<CTYPE>();
42-
const auto data_mask = mask.const_data_ptr<bool>();
43-
auto data_out = out.mutable_data_ptr<CTYPE>();
44-
for (size_t i = 0; i < n; ++i) {
45-
data_out[i] = data_mask[i] ? value_v : data_self[i];
46-
}
47-
}
48-
49-
} // namespace
50-
51-
/**
52-
* Copies `self` to `out` masking some elemnts with `value`.
53-
*
54-
* Asserts that `mask` tensor can be broadcasted to `self`, self and out should
55-
* have same dtype and size, and mask should be boolean tensor.
56-
*
57-
* masked_fill_Scalar_out(Tensor self, Tensor other, *, Scalar alpha=1.0,
58-
* Tensor(a!) out) -> Tensor(a!)
59-
*/
6021
Tensor& masked_fill_scalar_out(
6122
RuntimeContext& ctx,
62-
const Tensor& self,
23+
const Tensor& in,
6324
const Tensor& mask,
6425
const Scalar& value,
6526
Tensor& out) {
66-
ET_CHECK_MSG(
67-
tensor_is_broadcastable_to(mask, self),
68-
"masked_fill_scalar_out operateor can not broadcast mask to self");
27+
(void)ctx;
6928

70-
// The mask needs to be broadcasted iff its size differnet from the target one
71-
// (self.size())
72-
bool broadcasted = !self.sizes().equals(mask.sizes());
73-
const Tensor& broadcast_mask =
74-
broadcasted ? torch::executor::broadcast_tensor(mask, self) : mask;
29+
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
7530

76-
torch::executor::Error err = resize_tensor(out, self.sizes());
77-
ET_CHECK_MSG(
78-
err == torch::executor::Error::Ok,
79-
"Failed to resize out Tensor in masked_fill_scalar_out");
31+
ScalarType in_type = in.scalar_type();
32+
ScalarType mask_type = mask.scalar_type();
33+
ScalarType val_type = utils::get_scalar_dtype(value);
8034

81-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, out);
82-
ET_CHECK_SAME_SHAPE2(self, broadcast_mask);
83-
ET_CHECK_MSG(
84-
broadcast_mask.scalar_type() == ScalarType::Bool, "Unexpected mask type");
35+
ET_KERNEL_CHECK(ctx, mask_type == ScalarType::Bool, InvalidArgument, out);
8536

86-
#define MASKED_FILL(ctype, dtype) \
87-
case ScalarType::dtype: \
88-
masked_fill_kernel<ctype>(self, broadcast_mask, value, out); \
89-
break;
37+
resize_to_broadcast_target_size(in, mask, out);
9038

91-
switch (self.scalar_type()) {
92-
ET_FORALL_REAL_TYPES_AND(Bool, MASKED_FILL)
93-
default:
94-
ET_CHECK_MSG(false, "Unhandled dtype %hhd", self.scalar_type());
95-
}
39+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
40+
ET_SWITCH_REAL_TYPES_AND(Bool, val_type, ctx, __func__, CTYPE_VAL, [&]() {
41+
CTYPE_VAL value_v;
42+
ET_EXTRACT_SCALAR(value, value_v);
43+
CTYPE val = static_cast<CTYPE>(value_v);
9644

97-
#undef MASKED_FILL
98-
if (broadcasted) {
99-
free_broadcast_tensor(broadcast_mask);
100-
}
45+
apply_binary_elementwise_fn<CTYPE, bool, CTYPE>(
46+
[val](const CTYPE val_in, const bool val_mask) {
47+
return val_mask ? val : val_in;
48+
},
49+
in,
50+
mask,
51+
out);
52+
});
53+
});
10154

10255
return out;
10356
}

0 commit comments

Comments (0)