|
1 | 1 | #include "layers/SplitLayer.hpp" |
2 | 2 |
|
| 3 | +#include <algorithm> |
| 4 | +#include <cstring> |
| 5 | + |
3 | 6 | namespace it_lab_ai { |
4 | 7 |
|
5 | 8 | void SplitLayer::run(const Tensor& input, Tensor& output) { output = input; } |
@@ -51,26 +54,25 @@ void SplitLayer::split_impl(const Tensor& input, |
51 | 54 | } |
52 | 55 |
|
53 | 56 | const size_t input_axis_stride = shape[axis] * inner_size; |
54 | | - const size_t input_inner_stride = inner_size; |
55 | 57 |
|
56 | 58 | outputs.clear(); |
57 | 59 | outputs.reserve(part_sizes.size()); |
58 | 60 |
|
59 | 61 | size_t input_offset = 0; |
60 | 62 | for (size_t part = 0; part < part_sizes.size(); ++part) { |
61 | | - const size_t output_axis_size = part_sizes[part]; |
| 63 | + const size_t output_axis_size = static_cast<size_t>(part_sizes[part]); |
62 | 64 |
|
63 | 65 | std::vector<size_t> output_shape_vec(shape.dims()); |
64 | 66 | for (size_t i = 0; i < shape.dims(); ++i) { |
65 | | - output_shape_vec[i] = (i == axis) ? output_axis_size : shape[i]; |
| 67 | + output_shape_vec[i] = |
| 68 | + (static_cast<int>(i) == axis) ? output_axis_size : shape[i]; |
66 | 69 | } |
67 | 70 | Shape output_shape(output_shape_vec); |
68 | 71 |
|
69 | 72 | outputs.emplace_back(output_shape, input.get_type()); |
70 | 73 | auto& output_data = *outputs.back().as<T>(); |
71 | 74 |
|
72 | 75 | const size_t output_part_size = output_axis_size * inner_size; |
73 | | - const size_t input_part_size = output_part_size; |
74 | 76 |
|
75 | 77 | for (size_t outer = 0; outer < outer_size; ++outer) { |
76 | 78 | const T* input_start = |
@@ -106,12 +108,6 @@ void SplitLayer::validate(const Tensor& input) const { |
106 | 108 | if (*num_outputs_ <= 0) { |
107 | 109 | throw std::runtime_error("num_outputs must be positive"); |
108 | 110 | } |
109 | | - if (*num_outputs_ > axis_size) { |
110 | | - throw std::runtime_error("num_outputs cannot be greater than axis size"); |
111 | | - } |
112 | | - } |
113 | | - |
114 | | - if (!splits_ && num_outputs_) { |
115 | 111 | if (*num_outputs_ > axis_size) { |
116 | 112 | throw std::runtime_error("num_outputs (" + std::to_string(*num_outputs_) + |
117 | 113 | ") cannot be greater than axis size (" + |
|
0 commit comments