Skip to content

Commit b983fe2

Browse files
committed
fix
1 parent f96de58 commit b983fe2

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

src/layers/SplitLayer.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "layers/SplitLayer.hpp"
22

3+
#include <algorithm>
4+
#include <cstring>
5+
36
namespace it_lab_ai {
47

58
void SplitLayer::run(const Tensor& input, Tensor& output) { output = input; }
@@ -51,26 +54,25 @@ void SplitLayer::split_impl(const Tensor& input,
5154
}
5255

5356
const size_t input_axis_stride = shape[axis] * inner_size;
54-
const size_t input_inner_stride = inner_size;
5557

5658
outputs.clear();
5759
outputs.reserve(part_sizes.size());
5860

5961
size_t input_offset = 0;
6062
for (size_t part = 0; part < part_sizes.size(); ++part) {
61-
const size_t output_axis_size = part_sizes[part];
63+
const size_t output_axis_size = static_cast<size_t>(part_sizes[part]);
6264

6365
std::vector<size_t> output_shape_vec(shape.dims());
6466
for (size_t i = 0; i < shape.dims(); ++i) {
65-
output_shape_vec[i] = (i == axis) ? output_axis_size : shape[i];
67+
output_shape_vec[i] =
68+
(static_cast<int>(i) == axis) ? output_axis_size : shape[i];
6669
}
6770
Shape output_shape(output_shape_vec);
6871

6972
outputs.emplace_back(output_shape, input.get_type());
7073
auto& output_data = *outputs.back().as<T>();
7174

7275
const size_t output_part_size = output_axis_size * inner_size;
73-
const size_t input_part_size = output_part_size;
7476

7577
for (size_t outer = 0; outer < outer_size; ++outer) {
7678
const T* input_start =
@@ -106,12 +108,6 @@ void SplitLayer::validate(const Tensor& input) const {
106108
if (*num_outputs_ <= 0) {
107109
throw std::runtime_error("num_outputs must be positive");
108110
}
109-
if (*num_outputs_ > axis_size) {
110-
throw std::runtime_error("num_outputs cannot be greater than axis size");
111-
}
112-
}
113-
114-
if (!splits_ && num_outputs_) {
115111
if (*num_outputs_ > axis_size) {
116112
throw std::runtime_error("num_outputs (" + std::to_string(*num_outputs_) +
117113
") cannot be greater than axis size (" +

0 commit comments

Comments
 (0)