diff --git a/include/fusilli/attributes/pointwise_attributes.h b/include/fusilli/attributes/pointwise_attributes.h index cc39108b..9d433745 100644 --- a/include/fusilli/attributes/pointwise_attributes.h +++ b/include/fusilli/attributes/pointwise_attributes.h @@ -51,7 +51,7 @@ namespace fusilli { /* OP(LOG) */ \ /* OP(LOGICAL_AND) */ \ /* OP(LOGICAL_NOT) */ \ - /* OP(LOGICAL_OR) */ \ + OP(LOGICAL_OR) \ /* OP(MAX_OP) */ \ /* OP(MIN_OP) */ \ OP(MUL) \ @@ -138,6 +138,7 @@ inline const std::unordered_map {PointwiseAttr::Mode::CMP_GE, 2}, {PointwiseAttr::Mode::CMP_NEQ, 2}, {PointwiseAttr::Mode::DIV, 2}, + {PointwiseAttr::Mode::LOGICAL_OR, 2}, {PointwiseAttr::Mode::MUL, 2}, {PointwiseAttr::Mode::RELU_FWD, 1}, {PointwiseAttr::Mode::SIGMOID_FWD, 1}, diff --git a/include/fusilli/node/pointwise_node.h b/include/fusilli/node/pointwise_node.h index e79df913..392cc8be 100644 --- a/include/fusilli/node/pointwise_node.h +++ b/include/fusilli/node/pointwise_node.h @@ -36,6 +36,7 @@ static const std::unordered_map {PointwiseAttr::Mode::CMP_LE, DataType::Boolean}, {PointwiseAttr::Mode::CMP_GT, DataType::Boolean}, {PointwiseAttr::Mode::CMP_GE, DataType::Boolean}, + {PointwiseAttr::Mode::LOGICAL_OR, DataType::Boolean}, }; class PointwiseNode : public NodeCRTP { diff --git a/include/fusilli/support/asm_emitter.h b/include/fusilli/support/asm_emitter.h index ec67ff1a..7a94c9a3 100644 --- a/include/fusilli/support/asm_emitter.h +++ b/include/fusilli/support/asm_emitter.h @@ -1741,6 +1741,7 @@ inline ErrorOr PointwiseNode::emitNodePreAsm() const { FUSILLI_DECLARE_BINARY_TORCH_EMITTER(CMP_NEQ, torch.aten.ne.Tensor) FUSILLI_DECLARE_BINARY_TORCH_EMITTER(DIV, torch.aten.div.Tensor) FUSILLI_DECLARE_BINARY_TORCH_EMITTER(MUL, torch.aten.mul.Tensor) + FUSILLI_DECLARE_BINARY_TORCH_EMITTER(LOGICAL_OR, torch.aten.logical_or) FUSILLI_DECLARE_UNARY_TORCH_EMITTER(RELU_FWD, torch.aten.relu) FUSILLI_DECLARE_UNARY_TORCH_EMITTER(SIGMOID_FWD, torch.aten.sigmoid) FUSILLI_DECLARE_UNARY_TORCH_EMITTER(TANH_FWD, 
torch.aten.tanh) diff --git a/samples/pointwise/pointwise_binary_cmp_ops.cpp b/samples/pointwise/pointwise_binary_cmp_ops.cpp index d5c6728c..e77b03d8 100644 --- a/samples/pointwise/pointwise_binary_cmp_ops.cpp +++ b/samples/pointwise/pointwise_binary_cmp_ops.cpp @@ -49,7 +49,8 @@ TEST_CASE("Pointwise binary compare ops", "[pointwise][graph]") { PointwiseAttr::Mode::CMP_LE, PointwiseAttr::Mode::CMP_GT, PointwiseAttr::Mode::CMP_GE, - PointwiseAttr::Mode::CMP_NEQ); + PointwiseAttr::Mode::CMP_NEQ, + PointwiseAttr::Mode::LOGICAL_OR); // clang-format on auto execute = [&](Handle &handle, DataType dt, T x0, T x1) { @@ -132,6 +133,10 @@ TEST_CASE("Pointwise binary compare ops", "[pointwise][graph]") { y = (x0 != x1); break; } + case PointwiseAttr::Mode::LOGICAL_OR: { + y = (x0 != T(0)) || (x1 != T(0)); + break; + } default: FAIL( "Unsupported pointwise mode: " << PointwiseAttr::kModeToStr.at(mode)); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 644e19c0..dd8eb2b5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -153,6 +153,7 @@ add_fusilli_lit_tests( lit/test_pointwise_asm_emitter_cmp_lt.cpp lit/test_pointwise_asm_emitter_cmp_le.cpp lit/test_pointwise_asm_emitter_div.cpp + lit/test_pointwise_asm_emitter_logical_or.cpp lit/test_pointwise_asm_emitter_mul.cpp lit/test_pointwise_asm_emitter_mul_scalar.cpp lit/test_pointwise_asm_emitter_sigmoid.cpp diff --git a/tests/lit/test_pointwise_asm_emitter_logical_or.cpp b/tests/lit/test_pointwise_asm_emitter_logical_or.cpp new file mode 100644 index 00000000..b4ee11ca --- /dev/null +++ b/tests/lit/test_pointwise_asm_emitter_logical_or.cpp @@ -0,0 +1,66 @@ +// Copyright 2026 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// RUN: %{TEST_EXE} | iree-opt --verify-roundtrip +// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK +// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK + +// clang-format off +// +// TORCH-CHECK: module @module { +// TORCH-CHECK: func.func @main(%[[RESULT0:.+]]: !torch.tensor<[16,256,64,32],i1>, %[[ARG0:.+]]: !torch.vtensor<[16,256,64,32],f32>, %[[ARG1:.+]]: !torch.vtensor<[1,256,1,1],f32>) attributes {torch.assume_strict_symbolic_shapes} { +// TORCH-CHECK: %[[PERM0_0:.+]] = torch.constant.int 0 +// TORCH-CHECK: %[[PERM0_1:.+]] = torch.constant.int 1 +// TORCH-CHECK: %[[PERM0_2:.+]] = torch.constant.int 2 +// TORCH-CHECK: %[[PERM0_3:.+]] = torch.constant.int 3 +// TORCH-CHECK: %[[PERM0_LIST:.+]] = torch.prim.ListConstruct %[[PERM0_0]], %[[PERM0_1]], %[[PERM0_2]], %[[PERM0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> +// TORCH-CHECK: %[[PERMUTE0:.+]] = torch.aten.permute %[[ARG0]], %[[PERM0_LIST]] : !torch.vtensor<[16,256,64,32],f32>, !torch.list<int> -> !torch.vtensor<[16,256,64,32],f32> +// TORCH-CHECK: %[[PERM1_0:.+]] = torch.constant.int 0 +// TORCH-CHECK: %[[PERM1_1:.+]] = torch.constant.int 1 +// TORCH-CHECK: %[[PERM1_2:.+]] = torch.constant.int 2 +// TORCH-CHECK: %[[PERM1_3:.+]] = torch.constant.int 3 +// TORCH-CHECK: %[[PERM1_LIST:.+]] = torch.prim.ListConstruct %[[PERM1_0]], %[[PERM1_1]], %[[PERM1_2]], %[[PERM1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> +// TORCH-CHECK: %[[PERMUTE1:.+]] = torch.aten.permute %[[ARG1]], %[[PERM1_LIST]] : !torch.vtensor<[1,256,1,1],f32>, !torch.list<int> -> !torch.vtensor<[1,256,1,1],f32> +// TORCH-CHECK: %[[LOGICAL_OR:.+]] = torch.aten.logical_or %[[PERMUTE0]], %[[PERMUTE1]] : !torch.vtensor<[16,256,64,32],f32>, !torch.vtensor<[1,256,1,1],f32> -> !torch.vtensor<[16,256,64,32],i1> +// TORCH-CHECK: %[[PERM_OUT_0:.+]] = torch.constant.int 0 +// TORCH-CHECK: %[[PERM_OUT_1:.+]] = 
torch.constant.int 1 +// TORCH-CHECK: %[[PERM_OUT_2:.+]] = torch.constant.int 2 +// TORCH-CHECK: %[[PERM_OUT_3:.+]] = torch.constant.int 3 +// TORCH-CHECK: %[[PERM_OUT_LIST:.+]] = torch.prim.ListConstruct %[[PERM_OUT_0]], %[[PERM_OUT_1]], %[[PERM_OUT_2]], %[[PERM_OUT_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> +// TORCH-CHECK: %[[PERM_OUT:.+]] = torch.aten.permute %[[LOGICAL_OR]], %[[PERM_OUT_LIST]] : !torch.vtensor<[16,256,64,32],i1>, !torch.list<int> -> !torch.vtensor<[16,256,64,32],i1> +// TORCH-CHECK: torch.overwrite.tensor.contents %[[PERM_OUT]] overwrites %[[RESULT0]] : !torch.vtensor<[16,256,64,32],i1>, !torch.tensor<[16,256,64,32],i1> +// TORCH-CHECK: return +// TORCH-CHECK: } +// TORCH-CHECK: } +// +// AMDGPU-STATS-CHECK: "transient-memory-size": 0 +// AMDGPU-STATS-CHECK: "dispatch-count": 1 +// CPU-STATS-CHECK: "transient-memory-size": 0 +// CPU-STATS-CHECK: "dispatch-count": 1 +// +// clang-format on + +#include <fusilli/fusilli.h> + +#include "utils.h" + +#include <iostream> +#include <string> + +using namespace fusilli; + +int main(int argc, char **argv) { + std::string mode = (argc > 1) ? argv[1] : "default"; + + auto status = testBinaryPointwiseAsmEmitter( + "pointwise_asm_emitter_logical_or", "logical_or", mode, + PointwiseAttr::Mode::LOGICAL_OR, {16, 256, 64, 32}, {1, 256, 1, 1}); + if (isError(status)) { + std::cerr << "Test failed: " << status << std::endl; + return 1; + } + return 0; +}