This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7a49008

Author: DominikaJedynak
[BUGFIX] Type fix for large tensors (#20922)
* Dimension type fix
* Review suggestion
Parent: 588f541

2 files changed: +8, -8 lines


src/operator/subgraph/dnnl/dnnl_conv.cc

Lines changed: 2 additions & 2 deletions
@@ -59,11 +59,11 @@ static void UpdateConvWeightBias(NDArray* weight,
   const float* var_ptr = variance.data().dptr<float>();
   DType* update_weight_ptr = update_weight.data().dptr<DType>();
   DType* update_bias_ptr = update_bias.data().dptr<DType>();
-  size_t channel = gamma.shape()[0];
+  index_t channel = static_cast<index_t>(gamma.shape()[0]);
   const auto wshape = weight->shape();
   size_t offset = wshape.ProdShape(1, wshape.ndim());
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-  for (int c = 0; c < static_cast<int>(channel); ++c) {
+  for (index_t c = 0; c < channel; ++c) {
     const DType* p1 = weight_ptr + c * offset;
     DType* p2 = update_weight_ptr + c * offset;
     float alpha = (param->fix_gamma ? 1.0f : gamma_ptr[c]) / sqrt(var_ptr[c] + param->eps);
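
Context note (not part of the patch): with MXNet's large-tensor support, index_t is a 64-bit type, while the old loop iterated with a 32-bit int, so channel counts or index arithmetic beyond INT32_MAX could be silently truncated. A minimal standalone sketch of that truncation, using made-up sizes:

    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical large tensor: the flattened size exceeds INT32_MAX.
      int64_t channel = 70000;
      int64_t offset  = 70000;                   // elements per channel
      int64_t good    = channel * offset;        // 4'900'000'000, fits in 64 bits
      int     bad     = static_cast<int>(good);  // truncated: value no longer meaningful
      std::cout << good << " vs " << bad << "\n";
      return 0;
    }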

src/operator/subgraph/dnnl/dnnl_fc.cc

Lines changed: 6 additions & 6 deletions
@@ -235,15 +235,15 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   const mxnet::TShape oshape = output.shape();
   dnnl::memory::dims out_dims(2);
   if (oshape.ndim() == 2) {
-    out_dims[0] = static_cast<int>(oshape[0]);
-    out_dims[1] = static_cast<int>(oshape[1]);
+    out_dims[0] = static_cast<index_t>(oshape[0]);
+    out_dims[1] = static_cast<index_t>(oshape[1]);
   } else {
     if (!default_param.flatten) {
-      out_dims[0] = static_cast<int>(oshape.ProdShape(0, oshape.ndim() - 1));
-      out_dims[1] = static_cast<int>(oshape[oshape.ndim() - 1]);
+      out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 1));
+      out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
     } else {
-      out_dims[0] = static_cast<int>(static_cast<int>(oshape[0]));
-      out_dims[1] = static_cast<int>(oshape.ProdShape(1, oshape.ndim()));
+      out_dims[0] = static_cast<index_t>(oshape[0]);
+      out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
     }
   }
   dnnl::memory::desc out_md =
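
For illustration only: dnnl::memory::dims holds 64-bit dimensions, so routing a flattened shape product through static_cast<int> could truncate it before it was widened back. A self-contained sketch of the fixed pattern, where index_t and prod_shape are stand-ins for MXNet's index_t and TShape::ProdShape, not the real API:

    #include <cstdint>
    #include <vector>

    // Stand-ins for mxnet's index_t and TShape::ProdShape; not the real API.
    using index_t = int64_t;

    static index_t prod_shape(const std::vector<index_t>& shape, size_t begin, size_t end) {
      index_t p = 1;
      for (size_t i = begin; i < end; ++i)
        p *= shape[i];  // 64-bit accumulation, no truncation
      return p;
    }

    int main() {
      std::vector<index_t> oshape{70000, 70000, 16};  // hypothetical output shape
      std::vector<int64_t> out_dims(2);               // dnnl dims are 64-bit
      // Mirrors the !flatten branch: collapse every axis but the last.
      out_dims[0] = prod_shape(oshape, 0, oshape.size() - 1);  // 4.9e9, needs 64 bits
      out_dims[1] = oshape.back();
      // With the old static_cast<int>, out_dims[0] would have been truncated.
      return 0;
    }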
