Skip to content

Commit 97d6d49

Browse files
authored
[compute/cker] Reduce duplicated code and Fix wrong shape size. (#15158)
* [compute/cker] Reduce duplicated code and fix wrong shape size. This commit refactors dims storage initialization of Shape. - Reduce duplicated code - Fix wrong shape size ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com> * - Reverted the incorrectly replaced resize and fixed some flawed initializations. - Add Shape unittests. ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent 41228ec commit 97d6d49

File tree

3 files changed

+383
-49
lines changed

3 files changed

+383
-49
lines changed

runtime/compute/cker/include/cker/Shape.h

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -47,49 +47,22 @@ class Shape
4747
// Constructor that takes a dimension count.
4848
// If dimensions_count <= kMaxSmallSize, it uses a fixed-size array.
4949
// Otherwise, it uses a dynamic vector.
50-
explicit Shape(int dimensions_count) : _size(dimensions_count)
51-
{
52-
if (dimensions_count <= kMaxSmallSize)
53-
{
54-
dims_ = std::array<int32_t, kMaxSmallSize>{};
55-
}
56-
else
57-
{
58-
dims_ = std::vector<int32_t>(dimensions_count);
59-
}
60-
}
50+
explicit Shape(int dimensions_count) : _size(dimensions_count) { initStorage(dimensions_count); }
6151

6252
// Constructor that creates a shape of given size and fills all dimensions with "value".
63-
Shape(int shape_size, int32_t value) : _size(0)
53+
Shape(int shape_size, int32_t value) : _size(shape_size)
6454
{
65-
if (shape_size <= kMaxSmallSize)
66-
{
67-
dims_ = std::array<int32_t, kMaxSmallSize>{};
68-
}
69-
else
70-
{
71-
dims_ = std::vector<int32_t>(shape_size);
72-
}
73-
55+
initStorage(shape_size);
7456
for (int i = 0; i < shape_size; ++i)
7557
{
7658
SetDim(i, value);
7759
}
7860
}
7961

8062
// Constructor that creates a shape from an array of dimension data.
81-
Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
63+
Shape(int dimensions_count, const int32_t *dims_data) : _size(dimensions_count)
8264
{
83-
// Explicitly initialize dims_ based on dimensions_count to avoid uninitialized state.
84-
if (dimensions_count <= kMaxSmallSize)
85-
{
86-
dims_ = std::array<int32_t, kMaxSmallSize>{};
87-
}
88-
else
89-
{
90-
dims_ = std::vector<int32_t>(dimensions_count);
91-
}
92-
65+
initStorage(dimensions_count);
9366
ReplaceWith(dimensions_count, dims_data);
9467
}
9568

@@ -98,18 +71,7 @@ class Shape
9871
Shape(const std::initializer_list<int> init_list) : _size(0)
9972
{
10073
const auto size = static_cast<int>(std::distance(init_list.begin(), init_list.end()));
101-
102-
// Explicitly initialize dims_ based on the initializer list size to prevent
103-
// "maybe uninitialized" warnings when BuildFrom() is invoked.
104-
if (size <= kMaxSmallSize)
105-
{
106-
dims_ = std::array<int32_t, kMaxSmallSize>{};
107-
}
108-
else
109-
{
110-
dims_ = std::vector<int32_t>(size);
111-
}
112-
74+
initStorage(size);
11375
BuildFrom(init_list);
11476
}
11577

@@ -207,10 +169,7 @@ class Shape
207169
// initialize dims_ explicitly based on dimensions_count to ensure it is in a valid state.
208170
if (dims_.valueless_by_exception())
209171
{
210-
if (dimensions_count <= kMaxSmallSize)
211-
dims_ = std::array<int32_t, kMaxSmallSize>{};
212-
else
213-
dims_ = std::vector<int32_t>(dimensions_count);
172+
initStorage(dimensions_count);
214173
}
215174

216175
std::vector<int32_t> oldDims;
@@ -246,6 +205,7 @@ class Shape
246205
// Replaces the current shape with a new one defined by dimensions_count and dims_data.
247206
inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
248207
{
208+
assert(dims_data != nullptr);
249209
Resize(dimensions_count);
250210
std::memcpy(DimsData(), dims_data, dimensions_count * sizeof(int32_t));
251211
}
@@ -304,10 +264,20 @@ class Shape
304264
bool operator!=(const Shape &comp) const { return !((*this) == comp); }
305265

306266
private:
267+
// Helper function: initialize dims_ storage based on the number of dimensions.
268+
inline void initStorage(int dimensions_count)
269+
{
270+
assert(dimensions_count >= 0);
271+
if (dimensions_count <= kMaxSmallSize)
272+
dims_ = std::array<int32_t, kMaxSmallSize>{};
273+
else
274+
dims_ = std::vector<int32_t>(dimensions_count);
275+
}
276+
307277
// For use only by ExtendedShape(), written to guarantee (return-value) copy
308278
// elision in C++17.
309279
// This creates a shape padded to the desired size with the specified value.
310-
Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
280+
Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(new_shape_size)
311281
{
312282
assert(new_shape_size >= shape.DimensionsCount());
313283
assert(new_shape_size <= kMaxSmallSize);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef __DEATH_TEST_MACROS_H__
18+
#define __DEATH_TEST_MACROS_H__
19+
20+
#include <gtest/gtest.h>
21+
22+
// In release mode, assertions might not trigger abort() via assert(),
23+
// so we use EXPECT_EXIT and require that the statement exits with EXIT_FAILURE.
24+
// In debug mode, we use EXPECT_DEATH as usual.
25+
#ifdef NDEBUG
26+
#define EXPECT_EXIT_BY_ABRT_DEBUG_ONLY(statement, regex)
27+
#else
28+
#define EXPECT_EXIT_BY_ABRT_DEBUG_ONLY(statement, regex) \
29+
EXPECT_EXIT(statement, ::testing::KilledBySignal(SIGABRT), regex)
30+
#endif
31+
32+
#endif // __DEATH_TEST_MACROS_H__

0 commit comments

Comments
 (0)