Skip to content

Commit 0b95d33

Browse files
co63oc authored and maxiaolong001 committed
rename ctx to dev_ctx,xpu_ctx (PaddlePaddle#74513)
1 parent 7871f62 commit 0b95d33

19 files changed

+138
-136
lines changed

paddle/phi/kernels/legacy/compare_kernel.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,42 @@ limitations under the License. */
1919
namespace phi {
2020

2121
template <typename T, typename Context>
22-
void LessThanRawKernel(const Context& ctx,
22+
void LessThanRawKernel(const Context& dev_ctx,
2323
const DenseTensor& x,
2424
const DenseTensor& y,
2525
int axis,
2626
DenseTensor* out);
2727

2828
template <typename T, typename Context>
29-
void LessEqualRawKernel(const Context& ctx,
29+
void LessEqualRawKernel(const Context& dev_ctx,
3030
const DenseTensor& x,
3131
const DenseTensor& y,
3232
int axis,
3333
DenseTensor* out);
3434

3535
template <typename T, typename Context>
36-
void GreaterThanRawKernel(const Context& ctx,
36+
void GreaterThanRawKernel(const Context& dev_ctx,
3737
const DenseTensor& x,
3838
const DenseTensor& y,
3939
int axis,
4040
DenseTensor* out);
4141

4242
template <typename T, typename Context>
43-
void GreaterEqualRawKernel(const Context& ctx,
43+
void GreaterEqualRawKernel(const Context& dev_ctx,
4444
const DenseTensor& x,
4545
const DenseTensor& y,
4646
int axis,
4747
DenseTensor* out);
4848

4949
template <typename T, typename Context>
50-
void EqualRawKernel(const Context& ctx,
50+
void EqualRawKernel(const Context& dev_ctx,
5151
const DenseTensor& x,
5252
const DenseTensor& y,
5353
int axis,
5454
DenseTensor* out);
5555

5656
template <typename T, typename Context>
57-
void NotEqualRawKernel(const Context& ctx,
57+
void NotEqualRawKernel(const Context& dev_ctx,
5858
const DenseTensor& x,
5959
const DenseTensor& y,
6060
int axis,

paddle/phi/kernels/legacy/cpu/compare_kernel.cc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,88 +25,88 @@ template <typename T,
2525
typename Context,
2626
typename Functor,
2727
typename InverseFunctor>
28-
inline void CompareRawKernelImpl(const Context& ctx,
28+
inline void CompareRawKernelImpl(const Context& dev_ctx,
2929
const DenseTensor& x,
3030
const DenseTensor& y,
3131
int axis,
3232
DenseTensor* out) {
33-
ctx.template Alloc<bool>(out);
33+
dev_ctx.template Alloc<bool>(out);
3434
if (x.dims().size() >= y.dims().size()) {
3535
funcs::ElementwiseCompute<Functor, T, bool>(
36-
ctx, x, y, Functor(), out, axis);
36+
dev_ctx, x, y, Functor(), out, axis);
3737
} else {
3838
funcs::ElementwiseCompute<InverseFunctor, T, bool>(
39-
ctx, x, y, InverseFunctor(), out, axis);
39+
dev_ctx, x, y, InverseFunctor(), out, axis);
4040
}
4141
}
4242

4343
template <typename T, typename Context>
44-
void LessThanRawKernel(const Context& ctx,
44+
void LessThanRawKernel(const Context& dev_ctx,
4545
const DenseTensor& x,
4646
const DenseTensor& y,
4747
int axis,
4848
DenseTensor* out) {
4949
CompareRawKernelImpl<T,
5050
Context,
5151
funcs::LessThanFunctor<T>,
52-
funcs::GreaterThanFunctor<T>>(ctx, x, y, axis, out);
52+
funcs::GreaterThanFunctor<T>>(dev_ctx, x, y, axis, out);
5353
}
5454

5555
template <typename T, typename Context>
56-
void LessEqualRawKernel(const Context& ctx,
56+
void LessEqualRawKernel(const Context& dev_ctx,
5757
const DenseTensor& x,
5858
const DenseTensor& y,
5959
int axis,
6060
DenseTensor* out) {
6161
CompareRawKernelImpl<T,
6262
Context,
6363
funcs::LessEqualFunctor<T>,
64-
funcs::GreaterEqualFunctor<T>>(ctx, x, y, axis, out);
64+
funcs::GreaterEqualFunctor<T>>(dev_ctx, x, y, axis, out);
6565
}
6666

6767
template <typename T, typename Context>
68-
void GreaterThanRawKernel(const Context& ctx,
68+
void GreaterThanRawKernel(const Context& dev_ctx,
6969
const DenseTensor& x,
7070
const DenseTensor& y,
7171
int axis,
7272
DenseTensor* out) {
7373
CompareRawKernelImpl<T,
7474
Context,
7575
funcs::GreaterThanFunctor<T>,
76-
funcs::LessThanFunctor<T>>(ctx, x, y, axis, out);
76+
funcs::LessThanFunctor<T>>(dev_ctx, x, y, axis, out);
7777
}
7878
template <typename T, typename Context>
79-
void GreaterEqualRawKernel(const Context& ctx,
79+
void GreaterEqualRawKernel(const Context& dev_ctx,
8080
const DenseTensor& x,
8181
const DenseTensor& y,
8282
int axis,
8383
DenseTensor* out) {
8484
CompareRawKernelImpl<T,
8585
Context,
8686
funcs::GreaterEqualFunctor<T>,
87-
funcs::LessEqualFunctor<T>>(ctx, x, y, axis, out);
87+
funcs::LessEqualFunctor<T>>(dev_ctx, x, y, axis, out);
8888
}
8989
template <typename T, typename Context>
90-
void EqualRawKernel(const Context& ctx,
90+
void EqualRawKernel(const Context& dev_ctx,
9191
const DenseTensor& x,
9292
const DenseTensor& y,
9393
int axis,
9494
DenseTensor* out) {
9595
CompareRawKernelImpl<T,
9696
Context,
9797
funcs::EqualFunctor<T>,
98-
funcs::EqualFunctor<T>>(ctx, x, y, axis, out);
98+
funcs::EqualFunctor<T>>(dev_ctx, x, y, axis, out);
9999
}
100100
template <typename T, typename Context>
101-
void NotEqualRawKernel(const Context& ctx,
101+
void NotEqualRawKernel(const Context& dev_ctx,
102102
const DenseTensor& x,
103103
const DenseTensor& y,
104104
int axis,
105105
DenseTensor* out) {
106106
CompareRawKernelImpl<T,
107107
Context,
108108
funcs::NotEqualFunctor<T>,
109-
funcs::NotEqualFunctor<T>>(ctx, x, y, axis, out);
109+
funcs::NotEqualFunctor<T>>(dev_ctx, x, y, axis, out);
110110
}
111111
} // namespace phi
112112

paddle/phi/kernels/legacy/cpu/legacy_generate_proposals_kernel.cc

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace phi {
2828

2929
template <typename T>
3030
std::pair<phi::DenseTensor, phi::DenseTensor> ProposalForOneImage(
31-
const phi::CPUContext &ctx,
31+
const phi::CPUContext &dev_ctx,
3232
const phi::DenseTensor &im_info_slice,
3333
const phi::DenseTensor &anchors,
3434
const phi::DenseTensor &variances,
@@ -44,7 +44,7 @@ std::pair<phi::DenseTensor, phi::DenseTensor> ProposalForOneImage(
4444
// Sort index
4545
phi::DenseTensor index_t;
4646
index_t.Resize({scores_slice.numel()});
47-
int *index = ctx.Alloc<int>(&index_t);
47+
int *index = dev_ctx.Alloc<int>(&index_t);
4848
for (int i = 0; i < scores_slice.numel(); ++i) {
4949
index[i] = i;
5050
}
@@ -65,64 +65,65 @@ std::pair<phi::DenseTensor, phi::DenseTensor> ProposalForOneImage(
6565
bbox_sel.Resize({index_t.numel(), 4});
6666
anchor_sel.Resize({index_t.numel(), 4});
6767
var_sel.Resize({index_t.numel(), 4});
68-
ctx.Alloc<T>(&scores_sel);
69-
ctx.Alloc<T>(&bbox_sel);
70-
ctx.Alloc<T>(&anchor_sel);
71-
ctx.Alloc<T>(&var_sel);
68+
dev_ctx.Alloc<T>(&scores_sel);
69+
dev_ctx.Alloc<T>(&bbox_sel);
70+
dev_ctx.Alloc<T>(&anchor_sel);
71+
dev_ctx.Alloc<T>(&var_sel);
7272

73-
phi::funcs::CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
74-
phi::funcs::CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
75-
phi::funcs::CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
76-
phi::funcs::CPUGather<T>(ctx, variances, index_t, &var_sel);
73+
phi::funcs::CPUGather<T>(dev_ctx, scores_slice, index_t, &scores_sel);
74+
phi::funcs::CPUGather<T>(dev_ctx, bbox_deltas_slice, index_t, &bbox_sel);
75+
phi::funcs::CPUGather<T>(dev_ctx, anchors, index_t, &anchor_sel);
76+
phi::funcs::CPUGather<T>(dev_ctx, variances, index_t, &var_sel);
7777

7878
phi::DenseTensor proposals;
7979
proposals.Resize({index_t.numel(), 4});
80-
ctx.Alloc<T>(&proposals);
81-
phi::funcs::BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
80+
dev_ctx.Alloc<T>(&proposals);
81+
phi::funcs::BoxCoder<T>(
82+
dev_ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
8283

8384
phi::funcs::ClipTiledBoxes<T>(
84-
ctx, im_info_slice, proposals, &proposals, false);
85+
dev_ctx, im_info_slice, proposals, &proposals, false);
8586

8687
phi::DenseTensor keep;
8788
phi::funcs::FilterBoxes<T>(
88-
ctx, &proposals, min_size, im_info_slice, true, &keep);
89+
dev_ctx, &proposals, min_size, im_info_slice, true, &keep);
8990
// Handle the case when there is no keep index left
9091
if (keep.numel() == 0) {
9192
phi::funcs::SetConstant<phi::CPUContext, T> set_zero;
9293
bbox_sel.Resize({1, 4});
93-
ctx.Alloc<T>(&bbox_sel);
94-
set_zero(ctx, &bbox_sel, static_cast<T>(0));
94+
dev_ctx.Alloc<T>(&bbox_sel);
95+
set_zero(dev_ctx, &bbox_sel, static_cast<T>(0));
9596
phi::DenseTensor scores_filter;
9697
scores_filter.Resize({1, 1});
97-
ctx.Alloc<T>(&scores_filter);
98-
set_zero(ctx, &scores_filter, static_cast<T>(0));
98+
dev_ctx.Alloc<T>(&scores_filter);
99+
set_zero(dev_ctx, &scores_filter, static_cast<T>(0));
99100
return std::make_pair(bbox_sel, scores_filter);
100101
}
101102

102103
phi::DenseTensor scores_filter;
103104
bbox_sel.Resize({keep.numel(), 4});
104105
scores_filter.Resize({keep.numel(), 1});
105-
ctx.Alloc<T>(&bbox_sel);
106-
ctx.Alloc<T>(&scores_filter);
107-
phi::funcs::CPUGather<T>(ctx, proposals, keep, &bbox_sel);
108-
phi::funcs::CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
106+
dev_ctx.Alloc<T>(&bbox_sel);
107+
dev_ctx.Alloc<T>(&scores_filter);
108+
phi::funcs::CPUGather<T>(dev_ctx, proposals, keep, &bbox_sel);
109+
phi::funcs::CPUGather<T>(dev_ctx, scores_sel, keep, &scores_filter);
109110
if (nms_thresh <= 0) {
110111
return std::make_pair(bbox_sel, scores_filter);
111112
}
112113

113114
phi::DenseTensor keep_nms =
114-
phi::funcs::NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
115+
phi::funcs::NMS<T>(dev_ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
115116

116117
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
117118
keep_nms.Resize({post_nms_top_n});
118119
}
119120

120121
proposals.Resize({keep_nms.numel(), 4});
121122
scores_sel.Resize({keep_nms.numel(), 1});
122-
ctx.Alloc<T>(&proposals);
123-
ctx.Alloc<T>(&scores_sel);
124-
phi::funcs::CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
125-
phi::funcs::CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
123+
dev_ctx.Alloc<T>(&proposals);
124+
dev_ctx.Alloc<T>(&scores_sel);
125+
phi::funcs::CPUGather<T>(dev_ctx, bbox_sel, keep_nms, &proposals);
126+
phi::funcs::CPUGather<T>(dev_ctx, scores_filter, keep_nms, &scores_sel);
126127

127128
return std::make_pair(proposals, scores_sel);
128129
}

paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,20 @@ struct OneHotV2OpFunctor {
2525
const DenseTensor* in_;
2626
DenseTensor* out_;
2727
int depth_;
28-
const DeviceContext& ctx_;
28+
const DeviceContext& dev_ctx_;
2929

3030
OneHotV2OpFunctor(const DenseTensor* in,
3131
DenseTensor* out,
3232
int depth,
33-
const DeviceContext& ctx)
34-
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
33+
const DeviceContext& dev_ctx)
34+
: in_(in), out_(out), depth_(depth), dev_ctx_(dev_ctx) {}
3535

3636
template <typename OutT>
3737
void apply() const {
3838
auto* p_in_data = in_->data<InT>();
3939
auto numel = in_->numel();
40-
auto* p_out_data = ctx_.template Alloc<OutT>(out_);
41-
funcs::set_constant(ctx_, out_, 0.0);
40+
auto* p_out_data = dev_ctx_.template Alloc<OutT>(out_);
41+
funcs::set_constant(dev_ctx_, out_, 0.0);
4242

4343
for (int i = 0; i < numel; ++i) {
4444
PADDLE_ENFORCE_GE(

paddle/phi/kernels/legacy/gpu/layer_norm_cuda_kernel.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ void HostApplyRMSNorm(V* output,
943943
}
944944

945945
template <typename T, typename Context>
946-
void cuda_rms_norm(const Context& ctx,
946+
void cuda_rms_norm(const Context& dev_ctx,
947947
const DenseTensor& x,
948948
const DenseTensor& scale,
949949
int rows,
@@ -960,7 +960,7 @@ void cuda_rms_norm(const Context& ctx,
960960
cols, \
961961
epsilon, \
962962
const_cast<scalar_t_out*>(scale.data<scalar_t_out>()), \
963-
ctx.stream())
963+
dev_ctx.stream())
964964
// scale.dtype() same as y->dtype()
965965
if (scale.dtype() == phi::DataType::FLOAT32) {
966966
DISPATCH_FWD_CASE(float);
@@ -971,7 +971,7 @@ void cuda_rms_norm(const Context& ctx,
971971
}
972972

973973
template <typename T, typename U, typename V, typename Context>
974-
void HostRMSNormGradient(const Context& ctx,
974+
void HostRMSNormGradient(const Context& dev_ctx,
975975
const V* dout,
976976
const U* invvar,
977977
const DenseTensor& input,
@@ -992,7 +992,7 @@ void HostRMSNormGradient(const Context& ctx,
992992
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
993993
auto place = input.place();
994994
DenseTensor part_grad_gamma =
995-
phi::Empty<float, Context>(ctx, {part_size, n2});
995+
phi::Empty<float, Context>(dev_ctx, {part_size, n2});
996996
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
997997
dout,
998998
input.data<T>(),
@@ -1038,7 +1038,7 @@ void HostRMSNormGradient(const Context& ctx,
10381038
}
10391039

10401040
template <typename T, typename Context>
1041-
void cuda_rms_norm_gradient(const Context& ctx,
1041+
void cuda_rms_norm_gradient(const Context& dev_ctx,
10421042
const DenseTensor& x,
10431043
const DenseTensor& scale,
10441044
const DenseTensor& invvar,
@@ -1050,7 +1050,7 @@ void cuda_rms_norm_gradient(const Context& ctx,
10501050
DenseTensor* grad_scale) {
10511051
#define DISPATCH_BWD_CASE(scalar_t_out) \
10521052
HostRMSNormGradient<T, float, scalar_t_out, Context>( \
1053-
ctx, \
1053+
dev_ctx, \
10541054
dy.data<scalar_t_out>(), \
10551055
invvar.data<float>(), \
10561056
x, \
@@ -1060,7 +1060,7 @@ void cuda_rms_norm_gradient(const Context& ctx,
10601060
epsilon, \
10611061
grad_x->data<T>(), \
10621062
grad_scale->data<scalar_t_out>(), \
1063-
ctx.stream())
1063+
dev_ctx.stream())
10641064
if (scale.dtype() == phi::DataType::FLOAT32) {
10651065
DISPATCH_BWD_CASE(float);
10661066
} else if (scale.dtype() == phi::DataType::BFLOAT16) {

0 commit comments

Comments (0)