Skip to content

Commit 8b3326e

Browse files
mauriciocm9Mauricio Cortazar
andauthored
Add support for bool type in SplitToSequence (microsoft#24929)
### Description Add support for `bool` type to address the issue below. ### Motivation and Context This PR fixes microsoft#12286 Co-authored-by: Mauricio Cortazar <[email protected]>
1 parent 3426f64 commit 8b3326e

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

docs/OperatorKernels.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ Do not modify directly.*
437437
|||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
438438
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
439439
|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
440-
|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
440+
|SplitToSequence|*in* input:**T**<br> *in* split:**I**<br> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
441441
|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
442442
|||[6, 12]|**T** = tensor(double), tensor(float)|
443443
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/core/providers/cpu/sequence/sequence_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ ONNX_CPU_OPERATOR_KERNEL(
339339
11,
340340
KernelDefBuilder()
341341
.TypeConstraint("T",
342-
BuildKernelDefConstraints<float, MLFloat16, double, int32_t, int64_t, std::string>())
342+
BuildKernelDefConstraints<float, MLFloat16, double, int32_t, int64_t, bool, std::string>())
343343
.TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes())
344344
.TypeConstraint("I", BuildKernelDefConstraints<int32_t, int64_t>()),
345345
SplitToSequence);

onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,5 +536,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisDontKeepDims) {
536536
test.AddSeqOutput("S2", output);
537537
test.Run();
538538
}
539+
540+
TEST(SequenceOpsTest, SplitToSequence_BoolSplit) {
541+
OpTester test("SplitToSequence", 11);
542+
test.AddInput<bool>("input", {4, 2}, std::initializer_list<bool>({1, 1, 1, 1, 0, 0, 0, 0}));
543+
int64_t axis = 0;
544+
test.AddAttribute("axis", axis);
545+
SeqTensors<bool> output;
546+
output.AddTensor({1, 2}, {1, 1});
547+
output.AddTensor({1, 2}, {1, 1});
548+
output.AddTensor({1, 2}, {0, 0});
549+
output.AddTensor({1, 2}, {0, 0});
550+
test.AddSeqOutput("S2", output);
551+
test.Run();
552+
}
539553
} // namespace test
540-
} // namespace onnxruntime
554+
} // namespace onnxruntime

0 commit comments

Comments
 (0)