Skip to content

Commit dd1fe7c

Browse files
cyyeverpytorchmergebot
authored andcommitted
Remove clang-tidy type conversion suppressions (pytorch#166398)
This PR fixes and removes type conversion suppressions of clang-tidy. Pull Request resolved: pytorch#166398 Approved by: https://github.com/Skylion007
1 parent 695cb0d commit dd1fe7c

File tree

7 files changed

+34
-56
lines changed

7 files changed

+34
-56
lines changed

aten/src/ATen/native/ConvolutionTBC.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ Tensor conv_tbc(const Tensor& self, const Tensor& weight, const Tensor& bias, in
5252
for (const auto k : c10::irange(kw)) {
5353
int iShift = std::max(0, static_cast<int>(k - real_pad));
5454
int oShift = std::max(0, static_cast<int>(real_pad - k));
55-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
56-
int t = std::min(ilen + real_pad - k, olen) - oShift;
55+
long t = std::min(ilen + real_pad - k, olen) - oShift;
5756
// Note: gemm assumes column-major matrices
5857
// input is l*m (row-major)
5958
// weight is m*r (row-major)

aten/src/ATen/native/IndexingUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) {
1616
auto linearId = elements - 1;
1717

1818
// NOTE: Assumes all strides are positive, which is true for now
19-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
20-
for (int i = t.dim() - 1; i >= 0; --i) {
19+
for (auto i = t.dim() - 1; i >= 0; --i) {
2120
auto curDimIndex = linearId % t.sym_size(i);
2221
auto curDimOffset = curDimIndex * t.sym_stride(i);
2322
offset += curDimOffset;

aten/src/ATen/native/QuantizedLinear.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
6868
const float* input_ptr = input_contig.const_data_ptr<float>();
6969

7070
TORCH_CHECK(input.dim() >= 2);
71-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
7271
const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
7372
const int64_t K = input.size(input.dim() - 1);
7473
TORCH_CHECK(weight.dim() == 2);

aten/src/ATen/native/cpu/DistanceOpsKernel.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,9 @@ struct Dist {
160160
// value of k.
161161
parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [p, self_start, self_end, n, m, res_start](int64_t k, int64_t end) {
162162
const Vec pvec(p);
163-
double n2 = n - .5;
163+
double n2 = static_cast<double>(n) - .5;
164164
// The -1 accounts for floating point truncation issues
165-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
166-
int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
165+
int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2.0 * static_cast<double>(k) - 1.0)));
167166
int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
168167

169168
const scalar_t * self_i = self_start + i * m;

aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ void upsample_bilinear2d_out_frame(
7373
const auto rwidth = area_pixel_compute_scale<float>(
7474
input_width, output_width, align_corners, scales_w);
7575

76-
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
77-
float output_scale = output.q_scale() / input.q_scale();
76+
float output_scale = static_cast<float>(output.q_scale() / input.q_scale());
7877

7978
const int64_t input_q_zero_point = input.q_zero_point();
8079
const int64_t output_q_zero_point = output.q_zero_point();

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp

Lines changed: 29 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ Tensor qcat_nhwc_kernel(
148148
// Vectorized loop
149149
if (c + VLEN <= curr_C) {
150150
auto curr_scale_vec = Vectorized<float>(curr_scale);
151-
auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
151+
auto curr_zero_pt_vec = Vectorized<float>(curr_zero_pt);
152152
auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
153153
for (; c + VLEN <= curr_C; c += VLEN) {
154154
auto inp_vec = Vec::loadu(iptr + c);
@@ -174,7 +174,7 @@ Tensor qcat_nhwc_kernel(
174174
int64_t elem_size = curr_C - c;
175175
if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
176176
auto curr_scale_vec = Vectorized<float>(curr_scale);
177-
auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
177+
auto curr_zero_pt_vec = Vectorized<float>(curr_zero_pt);
178178
auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
179179
int64_t vec_num = elem_size / kVLEN;
180180
std::array<typename scalar_t::underlying, VLEN> buf_in{};
@@ -611,12 +611,10 @@ void qrelu_kernel(const Tensor& qx, Tensor& qy) {
611611
void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
612612
const Scalar& negval_) {
613613
int64_t i_zp = qx.q_zero_point();
614-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
615-
float i_scale = qx.q_scale();
614+
float i_scale = static_cast<float>(qx.q_scale());
616615

617616
int64_t o_zp = out.q_zero_point();
618-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
619-
float o_scale = out.q_scale();
617+
float o_scale = static_cast<float>(out.q_scale());
620618
float o_inv_scale = 1.0f / o_scale;
621619

622620
float negval = negval_.to<float>();
@@ -627,8 +625,8 @@ void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
627625
Vec zero_vec = Vec(0.0f);
628626
Vec one_vec = Vec(1.0f);
629627

630-
Vec i_scale_vec = Vec((float)i_scale);
631-
Vec i_zp_vec = Vec((float)i_zp);
628+
Vec i_scale_vec = Vec(i_scale);
629+
Vec i_zp_vec = Vec(i_zp);
632630
Vec i_scale_zp_neg_premul_vec = i_scale_vec * i_zp_vec.neg();
633631

634632
Vec negval_vec = Vec(negval);
@@ -738,10 +736,9 @@ void qprelu_out_kernel(Tensor& out,
738736

739737
void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
740738
int64_t zero_point = qx.q_zero_point();
741-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
742-
float scale = qx.q_scale();
739+
float scale = static_cast<float>(qx.q_scale());
743740
auto scale_vec = Vectorized<float>(scale);
744-
auto zero_point_vec = Vectorized<float>((float)zero_point);
741+
auto zero_point_vec = Vectorized<float>(zero_point);
745742
auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
746743
int64_t output_zero_point = zero_point;
747744
float output_scale = scale;
@@ -828,10 +825,9 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
828825
void qsigmoid_kernel(
829826
const Tensor& qx, Tensor& qy, double output_scale, int64_t output_zero_point ) {
830827
int64_t zero_point = qx.q_zero_point();
831-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
832-
float scale = qx.q_scale();
828+
float scale = static_cast<float>(qx.q_scale());
833829
auto scale_vec = Vectorized<float>(scale);
834-
auto zero_point_vec = Vectorized<float>((float)zero_point);
830+
auto zero_point_vec = Vectorized<float>(zero_point);
835831

836832
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() {
837833
float inv_output_scale = 1.0 / output_scale;
@@ -870,10 +866,9 @@ void qsigmoid_kernel(
870866

871867
void qhardsigmoid_kernel(const Tensor& qx, Tensor& qy) {
872868
int64_t zero_point = qx.q_zero_point();
873-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
874-
float scale = qx.q_scale();
869+
float scale = static_cast<float>(qx.q_scale());
875870
auto scale_vec = Vectorized<float>(scale);
876-
auto zero_point_vec = Vectorized<float>((float)zero_point);
871+
auto zero_point_vec = Vectorized<float>(zero_point);
877872
auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
878873

879874
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qhardsigmoid", [&]() {
@@ -1029,13 +1024,10 @@ void qthreshold_kernel(
10291024

10301025
// defines input and output scales and zero_points
10311026
int64_t input_zero_point = qx.q_zero_point();
1032-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1033-
float input_scale = qx.q_scale();
1027+
float input_scale = static_cast<float>(qx.q_scale());
10341028
int64_t output_zero_point = qy.q_zero_point();
1035-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1036-
float output_scale = qy.q_scale();
1037-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1038-
float inv_output_scale = 1.0 / output_scale;
1029+
float output_scale = static_cast<float>(qy.q_scale());
1030+
float inv_output_scale = static_cast<float>(1.0 / output_scale);
10391031

10401032
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qthreshold", [&]() {
10411033
qy = at::_empty_affine_quantized(
@@ -1096,8 +1088,7 @@ void qhardswish_kernel(const Tensor& qx, Tensor& qy) {
10961088

10971089
const auto o_scale = qy.q_scale();
10981090
const auto o_zero_point = qy.q_zero_point();
1099-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1100-
const float o_inv_scale = 1.0 / o_scale;
1091+
const float o_inv_scale = static_cast<float>(1.0 / o_scale);
11011092

11021093
using fVec = Vectorized<float>;
11031094
fVec i_scale_vec(i_scale);
@@ -1135,10 +1126,9 @@ void qhardswish_kernel(const Tensor& qx, Tensor& qy) {
11351126

11361127
void qtanh_kernel(const Tensor& qx, Tensor& qy) {
11371128
int64_t zero_point = qx.q_zero_point();
1138-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1139-
float scale = qx.q_scale();
1129+
float scale = static_cast<float>(qx.q_scale());
11401130
auto scale_vec = Vectorized<float>(scale);
1141-
auto zero_point_vec = Vectorized<float>((float)zero_point);
1131+
auto zero_point_vec = Vectorized<float>(zero_point);
11421132
auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
11431133

11441134
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qtanh", [&]() {
@@ -1198,16 +1188,13 @@ void qelu_kernel(
11981188
// they are NOT related to the quantization scale term
11991189

12001190
int64_t i_zp = qx.q_zero_point();
1201-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1202-
float i_scale = qx.q_scale();
1191+
float i_scale = static_cast<float>(qx.q_scale());
12031192

12041193
// In a future PR, we can improve on output scale and zero_point
12051194
// selection.
12061195
int64_t o_zp = qy.q_zero_point();
1207-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1208-
float o_scale = qy.q_scale();
1209-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1210-
float inv_o_scale = 1.0 / o_scale;
1196+
float o_scale = static_cast<float>(qy.q_scale());
1197+
float inv_o_scale = static_cast<float>(1.0 / o_scale);
12111198

12121199
float alpha_float = alpha.to<float>();
12131200
float scale_coef = scale.to<float>();
@@ -1227,7 +1214,7 @@ void qelu_kernel(
12271214
Vec scale_coef_vec = Vec(scale_coef);
12281215
Vec input_scale_coef_vec = Vec(input_scale_coef);
12291216
Vec i_scale_vec = Vec(i_scale);
1230-
Vec i_zero_point_vec = Vec((float)i_zp);
1217+
Vec i_zero_point_vec = Vec(i_zp);
12311218
Vec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
12321219

12331220
cpu_kernel_vec(
@@ -1326,23 +1313,20 @@ void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
13261313
template <bool ReLUFused = false>
13271314
void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
13281315
int64_t zero_point = out.q_zero_point();
1329-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1330-
float scale = out.q_scale();
1316+
float scale = static_cast<float>(out.q_scale());
13311317
float inv_scale = 1.0f / scale;
13321318
int64_t self_zero_point = self.q_zero_point();
1333-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1334-
float self_scale = self.q_scale();
1319+
float self_scale = static_cast<float>(self.q_scale());
13351320
int64_t other_zero_point = other.q_zero_point();
1336-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1337-
float other_scale = other.q_scale();
1321+
float other_scale = static_cast<float>(other.q_scale());
13381322

13391323
// Broadcast out the parameters here to amortize out that cost across
13401324
// loop iterations.
13411325
// TODO: we can optimize dequantization by doing a premultiplication
13421326
// of the zero point by scale and doing FMA on scale*x_q - (scale*zero_point)
1343-
auto self_zero_point_vec = Vectorized<float>((float)self_zero_point);
1327+
auto self_zero_point_vec = Vectorized<float>(self_zero_point);
13441328
auto self_scale_vec = Vectorized<float>(self_scale);
1345-
auto other_zero_point_vec = Vectorized<float>((float)other_zero_point);
1329+
auto other_zero_point_vec = Vectorized<float>(other_zero_point);
13461330
auto other_scale_vec = Vectorized<float>(other_scale);
13471331

13481332
auto self_scale_neg_zp_premul_vec = self_scale_vec * self_zero_point_vec.neg();
@@ -2965,7 +2949,7 @@ void quantized_normalize_kernel(
29652949
const bool beta_null = beta_data == nullptr;
29662950
int64_t x_zp = X.q_zero_point();
29672951
float x_scale = X.q_scale();
2968-
fVec x_zp_vec((float)x_zp);
2952+
fVec x_zp_vec(x_zp);
29692953
fVec one_vec(1.0f);
29702954
fVec zero_vec(0.0f);
29712955
float x_fake_scale = 1.0f;
@@ -3253,7 +3237,7 @@ void quantized_groupnorm_nhwc_kernel(
32533237
const bool beta_null = beta_data == nullptr;
32543238
int64_t x_zp = X.q_zero_point();
32553239
float x_scale = X.q_scale();
3256-
fVec x_zp_vec((float)x_zp);
3240+
fVec x_zp_vec(x_zp);
32573241
fVec one_vec(1.0f);
32583242
fVec zero_vec(0.0f);
32593243
float x_fake_scale = 1.0f;

aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,6 @@ at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
414414
TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
415415
TORCH_CHECK(input.dim() >= 2);
416416

417-
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
418417
const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
419418
const int64_t N = packed_weight_fp16.numCols();
420419
std::vector<int64_t> output_sizes = input.sizes().vec();

0 commit comments

Comments
 (0)