@@ -229,15 +229,29 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
229229 return nullptr ;
230230}
231231
// Forward declaration only; the definition is not part of this header chunk.
//
// Takes a tensor, a SupportedTensorDtypes selector, and the ScalarType used
// for computation, and returns a bool. The apply_*_elementwise_fn helpers in
// this file call it once per input/output tensor, AND the results together,
// and hand the combined value to ET_KERNEL_CHECK with InvalidArgument.
//
// Presumably returns true when `t`'s dtype is permitted by `dtypes` (with
// `compute_type` resolving any selector that depends on the common compute
// type) -- TODO confirm against the out-of-line definition.
//
// NOTE(review): `t` is taken by value here while the callers in this file
// receive `const Tensor&`; if this is changed to a reference, the definition
// must be updated in the same change.
232+ bool check_tensor_dtype (
233+ const Tensor t,
234+ SupportedTensorDtypes dtypes,
235+ const ScalarType compute_type);
236+
232237} // namespace internal
233238
234239template <typename CTYPE_COMMON, const char * op_name, typename Op>
235240inline void apply_unitensor_elementwise_fn (
236241 const Op& compute_fun,
242+ KernelRuntimeContext& ctx,
237243 const Tensor& a,
238244 SupportedTensorDtypes a_dtypes,
239245 const Tensor& out,
240246 SupportedTensorDtypes out_dtypes) {
247+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
248+
249+ ET_KERNEL_CHECK (
250+ ctx,
251+ (internal::check_tensor_dtype (a, a_dtypes, compute_type) &&
252+ internal::check_tensor_dtype (out, out_dtypes, compute_type)),
253+ InvalidArgument, );
254+
241255 const auto load_a_to_common =
242256 internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
243257 const auto store_common_to_out =
@@ -263,12 +277,22 @@ inline void apply_unitensor_elementwise_fn(
263277template <typename CTYPE_COMMON, const char * op_name, typename Op>
264278inline void apply_bitensor_elementwise_fn (
265279 const Op& compute_fun,
280+ KernelRuntimeContext& ctx,
266281 const Tensor& a,
267282 SupportedTensorDtypes a_dtypes,
268283 const Tensor& b,
269284 SupportedTensorDtypes b_dtypes,
270285 const Tensor& out,
271286 SupportedTensorDtypes out_dtypes) {
287+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
288+
289+ ET_KERNEL_CHECK (
290+ ctx,
291+ (internal::check_tensor_dtype (a, a_dtypes, compute_type) &&
292+ internal::check_tensor_dtype (b, b_dtypes, compute_type) &&
293+ internal::check_tensor_dtype (out, out_dtypes, compute_type)),
294+ InvalidArgument, );
295+
272296 const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
273297 const bool b_is_broadcasted = !out.sizes ().equals (b.sizes ());
274298 const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
@@ -312,9 +336,9 @@ inline void apply_bitensor_elementwise_fn(
312336}
313337
314338/* *
315- * Useful for tri-tensor elementwise operators. For each element of the inputs,
316- * perform a computation and write to the corresponding element of the output.
317- * Tensor broadcasting is applied wherever it is required.
339+ * Useful for tri-tensor elementwise operators. For each element of the
340+ * inputs, perform a computation and write to the corresponding element of the
341+ * output. Tensor broadcasting is applied wherever it is required.
318342 *
319343 * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
320344 * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
@@ -334,6 +358,7 @@ inline void apply_bitensor_elementwise_fn(
334358template <typename CTYPE_COMMON, const char * op_name, typename Op>
335359inline void apply_tritensor_elementwise_fn (
336360 const Op& compute_fun,
361+ KernelRuntimeContext& ctx,
337362 const Tensor& a,
338363 SupportedTensorDtypes a_dtypes,
339364 const Tensor& b,
@@ -342,6 +367,16 @@ inline void apply_tritensor_elementwise_fn(
342367 SupportedTensorDtypes c_dtypes,
343368 const Tensor& out,
344369 SupportedTensorDtypes out_dtypes) {
370+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
371+
372+ ET_KERNEL_CHECK (
373+ ctx,
374+ (internal::check_tensor_dtype (a, a_dtypes, compute_type) &&
375+ internal::check_tensor_dtype (b, b_dtypes, compute_type) &&
376+ internal::check_tensor_dtype (c, c_dtypes, compute_type) &&
377+ internal::check_tensor_dtype (out, out_dtypes, compute_type)),
378+ InvalidArgument, );
379+
345380 const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
346381 const bool b_is_broadcasted = !out.sizes ().equals (b.sizes ());
347382 const bool c_is_broadcasted = !out.sizes ().equals (c.sizes ());
0 commit comments