@@ -17,87 +17,40 @@ namespace native {
17
17
using Tensor = exec_aten::Tensor;
18
18
using ScalarType = exec_aten::ScalarType;
19
19
using Scalar = exec_aten::Scalar;
20
- namespace {
21
20
22
- /* *
23
- * Fills the `out` with values of `self` or `value` based on mask.
24
- *
25
- * Assumes that the tensors are contiguous, are the same shape,
26
- * input and output have the same time and mask is tensor of bools.
27
- * CTYPE should be the C type (like `float` or `int`) that matches
28
- * the dtype of the tensors.
29
- */
30
- template <class CTYPE >
31
- void masked_fill_kernel (
32
- const Tensor& self,
33
- const Tensor& mask,
34
- const Scalar& value,
35
- Tensor& out) {
36
- ET_DCHECK (self.numel () == mask.numel () && self.numel () == out.numel ());
37
- CTYPE value_v = 0 ;
38
- bool ok = utils::extract_scalar (value, &value_v);
39
- ET_CHECK_MSG (ok, " Invalid fill value: wrong type or out of range" );
40
- const size_t n = self.numel ();
41
- const auto data_self = self.const_data_ptr <CTYPE>();
42
- const auto data_mask = mask.const_data_ptr <bool >();
43
- auto data_out = out.mutable_data_ptr <CTYPE>();
44
- for (size_t i = 0 ; i < n; ++i) {
45
- data_out[i] = data_mask[i] ? value_v : data_self[i];
46
- }
47
- }
48
-
49
- } // namespace
50
-
51
/**
 * Copies `self` to `out`, masking some elements with `value`.
 *
 * Asserts that the `mask` tensor can be broadcast to `self`, that `self`
 * and `out` have the same dtype and size, and that `mask` is a boolean
 * tensor.
 *
 * masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value,
 *     Tensor(a!) out) -> Tensor(a!)
 */
60
21
/**
 * masked_fill.Scalar_out: copies `in` to `out`, replacing the elements
 * selected by `mask` with `value`.
 *
 * Checks that `in` and `out` share a dtype and that `mask` is a boolean
 * tensor, then resizes `out` to the broadcast target size of `in` and
 * `mask` and applies an element-wise select.
 *
 * @param ctx   Kernel runtime context used by the ET_KERNEL_CHECK /
 *              ET_SWITCH error-reporting macros.
 * @param in    Source tensor.
 * @param mask  Boolean mask tensor, broadcastable against `in`.
 * @param value Scalar fill value; cast to `in`'s dtype before use.
 * @param out   Output tensor; also returned to allow call chaining.
 */
Tensor& masked_fill_scalar_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mask,
    const Scalar& value,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType mask_type = mask.scalar_type();
  ScalarType val_type = utils::get_scalar_dtype(value);

  ET_KERNEL_CHECK(ctx, mask_type == ScalarType::Bool, InvalidArgument, out);

  // NOTE(review): the Error returned by resize_to_broadcast_target_size
  // is ignored here; if the resize can fail at runtime, this should be
  // wrapped in ET_KERNEL_CHECK — confirm against sibling ops.
  resize_to_broadcast_target_size(in, mask, out);

  // Outer switch: the tensors' dtype. Inner switch: the scalar's own
  // type, which is extracted first and then cast to the tensor dtype.
  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, val_type, ctx, __func__, CTYPE_VAL, [&]() {
      CTYPE_VAL value_v;
      ET_EXTRACT_SCALAR(value, value_v);
      CTYPE val = static_cast<CTYPE>(value_v);

      // Broadcasting-aware element-wise select between `in` and the
      // fill value, keyed on the boolean mask.
      apply_binary_elementwise_fn<CTYPE, bool, CTYPE>(
          [val](const CTYPE val_in, const bool val_mask) {
            return val_mask ? val : val_in;
          },
          in,
          mask,
          out);
    });
  });

  return out;
}
0 commit comments