Skip to content

Commit 303e2a0

Browse files
authored
Add boolean support for op_unbind_copy
Differential Revision: D81705165 Pull Request resolved: #13956
1 parent 7a7e939 commit 303e2a0

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

kernels/portable/cpu/op_unbind_copy.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ void unbind_copy_int_out(
5555
ScalarType in_type = input.scalar_type();
5656
ScalarType out_type = out[0].scalar_type();
5757

58-
ET_SWITCH_REALHBF16_TYPES(
58+
ET_SWITCH_REALHBBF16_TYPES(
5959
in_type, ctx, "unbind_copy.int_out", CTYPE_IN, [&]() {
60-
ET_SWITCH_REALHBF16_TYPES(
60+
ET_SWITCH_REALHBBF16_TYPES(
6161
out_type, ctx, "unbind_copy.int_out", CTYPE_OUT, [&]() {
6262
const CTYPE_IN* const input_data =
6363
input.const_data_ptr<CTYPE_IN>();

kernels/test/op_unbind_copy_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,54 @@ TEST_F(OpUnbindCopyIntOutTest, DynamicShapeUnbound) {
374374
test_dynamic_shape(
375375
{1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
376376
}
377+
378+
TEST_F(OpUnbindCopyIntOutTest, BooleanTensorUnbindDim2) {
379+
// Test case with inputs:
380+
// ArgType.Tensor torch.bool (1, 7, 4)
381+
// ArgType.Dim 2
382+
TensorFactory<ScalarType::Bool> tf;
383+
TensorListFactory<ScalarType::Bool> tlf;
384+
385+
// Create input tensor of shape (1, 7, 4) filled with bool values
386+
Tensor input = tf.zeros({1, 7, 4});
387+
auto in_data = input.mutable_data_ptr<bool>();
388+
389+
// Fill with alternating true/false pattern
390+
for (int i = 0; i < 1 * 7 * 4; i++) {
391+
in_data[i] = (i % 2) == 0;
392+
}
393+
394+
// Unbinding along dimension 2 should produce 4 tensors of shape (1, 7)
395+
int64_t unbind_dim = 2;
396+
int64_t num_outputs = input.size(unbind_dim); // Should be 4
397+
398+
// Create output tensors
399+
std::vector<Tensor> outputs;
400+
for (int i = 0; i < num_outputs; i++) {
401+
outputs.push_back(tf.zeros({1, 7}));
402+
}
403+
TensorList out = tlf.zeros_like(outputs);
404+
405+
// Perform unbind operation - boolean tensors are now supported
406+
op_unbind_copy_int_out(input, unbind_dim, out);
407+
408+
// Verify outputs
409+
for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
410+
EXPECT_EQ(out[output_idx].dim(), 2);
411+
EXPECT_EQ(out[output_idx].size(0), 1);
412+
EXPECT_EQ(out[output_idx].size(1), 7);
413+
414+
auto out_data = out[output_idx].const_data_ptr<bool>();
415+
416+
// Verify the data correctness
417+
for (int i = 0; i < 1; i++) {
418+
for (int j = 0; j < 7; j++) {
419+
int input_idx = i * 7 * 4 + j * 4 + output_idx;
420+
bool expected = (input_idx % 2) == 0;
421+
EXPECT_EQ(out_data[i * 7 + j], expected)
422+
<< "Mismatch at output[" << output_idx << "][" << i << "][" << j
423+
<< "]";
424+
}
425+
}
426+
}
427+
}

0 commit comments

Comments
 (0)