@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/collective/c_allgather_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"

 #if defined(PADDLE_WITH_CNCL)
 #include "paddle/fluid/platform/collective_helper.h"
@@ -27,15 +28,14 @@ template <typename T>
 class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    auto place = ctx.GetPlace();
+    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
 #if defined(PADDLE_WITH_CNCL)
-    auto x = ctx.Input<framework::Tensor>("X");
-    auto out = ctx.Output<framework::Tensor>("Out");
-    cnclDataType_t dtype =
-        platform::ToCNCLDataType(framework::TransToProtoVarType(x->dtype()));
+    auto x = ctx.Input<phi::DenseTensor>("X");
+    auto out = ctx.Output<phi::DenseTensor>("Out");

     int nranks = ctx.Attr<int>("nranks");
     int rid = ctx.Attr<int>("ring_id");
-    auto place = ctx.GetPlace();
     auto comm = platform::CNCLCommContext::Instance().Get(rid, place);
     PADDLE_ENFORCE_EQ(
         nranks,
@@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(out_dims, place);

     uint32_t send_numel = x->numel();
-    void* send_buff = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
-    void* recv_buff = reinterpret_cast<void*>(out->data<T>());
+    void* send_buff;
+    void* recv_buff;
+    phi::DenseTensor in_tensor, out_tensor;
+    if (framework::TransToProtoVarType(x->dtype()) ==
+        framework::proto::VarType::INT64) {
+      // Cast from int64 to int32, since CNCL does not support int64.
+      in_tensor.mutable_data<int32_t>(x->dims(), place);
+      out_tensor.mutable_data<int32_t>(out->dims(), place);
+      MLUCnnlTensorDesc x_int64_desc(*x);
+      MLUCnnlTensorDesc x_int32_desc(in_tensor);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32);
+      MLUCnnl::Cast(ctx,
+                    cast_type,
+                    x_int64_desc.get(),
+                    GetBasePtr(x),
+                    x_int32_desc.get(),
+                    GetBasePtr(&in_tensor));
+      send_buff = reinterpret_cast<void*>(in_tensor.data<int32_t>());
+      recv_buff = reinterpret_cast<void*>(out_tensor.data<int32_t>());
+    } else {
+      in_tensor.ShareDataWith(*x);
+      out_tensor.ShareDataWith(*out);
+      send_buff = reinterpret_cast<void*>(in_tensor.data<T>());
+      recv_buff = reinterpret_cast<void*>(out_tensor.data<T>());
+    }

     mluStream stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream();
     } else {
       stream = comm->stream();
     }
+    cnclDataType_t dtype = platform::ToCNCLDataType(
+        framework::TransToProtoVarType(in_tensor.dtype()));

     PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(
         send_buff, recv_buff, send_numel, dtype, comm->comm(), stream));
+    if (framework::TransToProtoVarType(x->dtype()) ==
+        framework::proto::VarType::INT64) {
+      // Cast the int32 out_tensor back to the int64 output.
+      MLUCnnlTensorDesc out_int64_desc(*out);
+      MLUCnnlTensorDesc out_int32_desc(out_tensor);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
+      MLUCnnl::Cast(ctx,
+                    cast_type,
+                    out_int32_desc.get(),
+                    GetBasePtr(&out_tensor),
+                    out_int64_desc.get(),
+                    GetBasePtr(out));
+    }
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with MLU."));
@@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather,
                        ops::CAllGatherOpMLUKernel<int>,
                        ops::CAllGatherOpMLUKernel<int8_t>,
                        ops::CAllGatherOpMLUKernel<int16_t>,
+                       ops::CAllGatherOpMLUKernel<int64_t>,
                        ops::CAllGatherOpMLUKernel<plat::float16>);
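The int64 path added above works around CNCL's lack of an int64 data type: the kernel casts the input tensor down to int32, runs cnclAllGather on the int32 buffers, then casts the gathered result back to int64. Below is a minimal standalone sketch of that cast-around-the-collective pattern; FakeAllGather is a hypothetical single-process stand-in for cnclAllGather (no MLU runtime required), and note that values outside the int32 range would be silently truncated by this scheme.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for cnclAllGather: each of `nranks` ranks contributes
// `send`, and every rank receives the concatenation of all contributions.
// Simulated here in one process by repeating the same buffer.
std::vector<int32_t> FakeAllGather(const std::vector<int32_t>& send,
                                   int nranks) {
  std::vector<int32_t> recv;
  for (int r = 0; r < nranks; ++r) {
    recv.insert(recv.end(), send.begin(), send.end());
  }
  return recv;
}

int main() {
  const int nranks = 2;
  // int64 input, which the collective library (per the diff, CNCL) cannot
  // transfer directly.
  std::vector<int64_t> x = {1, 2, 3};

  // Step 1: cast int64 -> int32 (mirrors the first MLUCnnl::Cast in the diff).
  // Values outside the int32 range would be truncated here.
  std::vector<int32_t> x32(x.begin(), x.end());

  // Step 2: run the collective on the int32 buffer.
  std::vector<int32_t> out32 = FakeAllGather(x32, nranks);

  // Step 3: cast the gathered int32 result back to int64 (the second Cast).
  std::vector<int64_t> out(out32.begin(), out32.end());

  for (int64_t v : out) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 2 3 1 2 3
}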