@@ -30,119 +30,126 @@ void CSoftmaxWithEntropyKernel(const Context& dev_ctx,
                                const DenseTensor& logits_in,
                                const DenseTensor& label_in,
                                int64_t ignore_index,
-                               int ring_id,
                                int rank,
                                int nranks,
                                DenseTensor* softmax,
                                DenseTensor* loss) {
-  const int rid = ring_id;
-  auto map = distributed::ProcessGroupMapFromGid::getInstance();
-  if (map->has(rid)) {
-    const phi::DenseTensor* logits = &logits_in;
-    const phi::DenseTensor* labels = &label_in;
-    auto softmax_dims = softmax->dims();
-    auto loss_dims = loss->dims();
-
-    const int rid = ring_id;
-
-    distributed::ProcessGroup* pg = map->get(rid);
-    distributed::AllreduceOptions opts;
-
-    // allocate memory on device.
-    const auto& logits_dims = logits->dims();
-
-    const int axis = logits_dims.size() - 1;
-    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
-    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
-
-    auto logits_2d = std::make_shared<phi::DenseTensor>();
-    auto labels_1d = std::make_shared<phi::DenseTensor>();
-    logits_2d->ShareDataWith(*logits).Resize({N, D});
-    labels_1d->ShareDataWith(*labels).Resize({N});
-    paddle::Tensor logits_2d_tensor(logits_2d), labels_1d_tensor(labels_1d);
-
-    // step 1, obtain logit_max
-    auto logits_2d_max_tensor = logits_2d_tensor.max({1}, true);
-    std::vector<phi::DenseTensor> in_out;
-    in_out.push_back(*reinterpret_cast<phi::DenseTensor*>(
-        logits_2d_max_tensor.impl().get()));
-    opts.reduce_op = distributed::ReduceOp::MAX;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
-
-    // step 2, obtain logit - logit_max
-    auto logits_2d_sub_max = paddle::experimental::clip(
-        logits_2d_tensor - logits_2d_max_tensor, -64., 0.);
-
-    // step 3, obtain predict target
-    const int start_index = rank * D;
-    auto start_index_tensor =
-        paddle::experimental::full_like(labels_1d_tensor,
-                                        start_index,
-                                        labels_1d_tensor.dtype(),
-                                        labels_1d_tensor.place());
-    auto end_index_tensor =
-        paddle::experimental::full_like(labels_1d_tensor,
-                                        start_index + D,
-                                        labels_1d_tensor.dtype(),
-                                        labels_1d_tensor.place());
-    auto labels_1d_mask = paddle::experimental::logical_and(
-        labels_1d_tensor.greater_equal(start_index_tensor),
-        labels_1d_tensor.less_than(end_index_tensor));
-    auto real_label_tensor = (labels_1d_tensor - start_index_tensor)
-                                 .multiply(paddle::experimental::cast(
-                                     labels_1d_mask, labels_1d_tensor.dtype()));
-
-    auto predicted_logits_tensor =
-        logits_2d_sub_max
-            .multiply(paddle::experimental::cast(
-                paddle::experimental::one_hot(real_label_tensor, D),
-                logits_2d_sub_max.dtype()))
-            .sum({1}, logits_2d_sub_max.dtype(), false)
-            .multiply(paddle::experimental::cast(labels_1d_mask,
-                                                 logits_2d_sub_max.dtype()));
-
-    in_out.clear();
-    in_out.push_back(*reinterpret_cast<phi::DenseTensor*>(
-        predicted_logits_tensor.impl().get()));
-    opts.reduce_op = distributed::ReduceOp::SUM;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
-
-    // step 4, obtain exp(logit)
-    auto softmax_2d_tensor = logits_2d_sub_max.exp();
-
-    // step 5, obtain sum_exp_logits
-    auto sum_exp_logits_tensor =
-        softmax_2d_tensor.sum({1}, softmax_2d_tensor.dtype(), false);
-
-    in_out.clear();
-    in_out.push_back(*reinterpret_cast<phi::DenseTensor*>(
-        sum_exp_logits_tensor.impl().get()));
-    opts.reduce_op = distributed::ReduceOp::SUM;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
-
-    auto softmax_out = softmax_2d_tensor.divide(
-        paddle::experimental::reshape(sum_exp_logits_tensor, {N, 1}));
-    auto labels_1d_not_equal_ignore = labels_1d_tensor.not_equal(
-        paddle::experimental::full_like(labels_1d_tensor,
-                                        ignore_index,
-                                        labels_1d_tensor.dtype(),
-                                        labels_1d_tensor.place()));
-    auto loss_out =
-        (sum_exp_logits_tensor.log() - predicted_logits_tensor)
-            .multiply(paddle::experimental::cast(
-                labels_1d_not_equal_ignore, sum_exp_logits_tensor.dtype()));
-    softmax
-        ->ShareDataWith(
-            *reinterpret_cast<phi::DenseTensor*>(softmax_out.impl().get()))
-        .Resize(softmax_dims);
-    loss->ShareDataWith(
-            *reinterpret_cast<phi::DenseTensor*>(loss_out.impl().get()))
-        .Resize(loss_dims);
-  } else {
-    PADDLE_THROW(
-        common::errors::Unavailable("CustomDevice c_softmax_with_cross_entropy "
-                                    "only support ProcessGroup"));
-  }
+  auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
+      dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(comm,
+                    nullptr,
+                    common::errors::Unavailable(
+                        "XCCLCommContext is nullptr, collective op should "
+                        "has ring_id attr."));
+
+  const phi::DenseTensor* logits = &logits_in;
+  const phi::DenseTensor* labels = &label_in;
+  auto softmax_dims = softmax->dims();
+  auto loss_dims = loss->dims();
+
+  const int axis = logits->dims().size() - 1;
+  const int N = phi::funcs::SizeToAxis(axis, logits->dims());
+  const int D = phi::funcs::SizeFromAxis(axis, logits->dims());
+
+  auto logits_2d = std::make_shared<phi::DenseTensor>();
+  auto labels_1d = std::make_shared<phi::DenseTensor>();
+  logits_2d->ShareDataWith(*logits).Resize({N, D});
+  labels_1d->ShareDataWith(*labels).Resize({N});
+  paddle::Tensor logits_2d_tensor(logits_2d), labels_1d_tensor(labels_1d);
+
+  // step 1, obtain logit_max
+  auto logits_2d_max_tensor = logits_2d_tensor.max({1}, true);
+  auto logits_2d_max =
+      reinterpret_cast<phi::DenseTensor*>(logits_2d_max_tensor.impl().get());
+  auto& stream = *dev_ctx.GetStream();
+  phi::DeviceManager::CCLAllReduce(dev_ctx.GetPlace().GetDeviceType(),
+                                   logits_2d_max->data<float>(),
+                                   logits_2d_max->data<float>(),
+                                   logits_2d_max->numel(),
+                                   logits_2d_max->dtype(),
+                                   phi::ccl::CCLReduceOp::MAX,
+                                   comm->GetXcclComm(),
+                                   stream);
+
+  // step 2, obtain logit - logit_max
+  auto logits_2d_sub_max = paddle::experimental::clip(
+      logits_2d_tensor - logits_2d_max_tensor, -64., 0.);
+
+  // step 3, obtain predict target
+  const int start_index = rank * D;
+  auto start_index_tensor =
+      paddle::experimental::full_like(labels_1d_tensor,
+                                      start_index,
+                                      labels_1d_tensor.dtype(),
+                                      labels_1d_tensor.place());
+  auto end_index_tensor =
+      paddle::experimental::full_like(labels_1d_tensor,
+                                      start_index + D,
+                                      labels_1d_tensor.dtype(),
+                                      labels_1d_tensor.place());
+  auto labels_1d_mask = paddle::experimental::logical_and(
+      labels_1d_tensor.greater_equal(start_index_tensor),
+      labels_1d_tensor.less_than(end_index_tensor));
+  auto real_label_tensor = (labels_1d_tensor - start_index_tensor)
+                               .multiply(paddle::experimental::cast(
+                                   labels_1d_mask, labels_1d_tensor.dtype()));
+
+  auto predicted_logits_tensor =
+      logits_2d_sub_max
+          .multiply(paddle::experimental::cast(
+              paddle::experimental::one_hot(real_label_tensor, D),
+              logits_2d_sub_max.dtype()))
+          .sum({1}, logits_2d_sub_max.dtype(), false)
+          .multiply(paddle::experimental::cast(labels_1d_mask,
+                                               logits_2d_sub_max.dtype()));
+
+  auto predicted_logits =
+      reinterpret_cast<phi::DenseTensor*>(predicted_logits_tensor.impl().get());
+  phi::DeviceManager::CCLAllReduce(dev_ctx.GetPlace().GetDeviceType(),
+                                   predicted_logits->data<float>(),
+                                   predicted_logits->data<float>(),
+                                   predicted_logits->numel(),
+                                   predicted_logits->dtype(),
+                                   phi::ccl::CCLReduceOp::SUM,
+                                   comm->GetXcclComm(),
+                                   stream);
+
+  // step 4, obtain exp(logit)
+  auto softmax_2d_tensor = logits_2d_sub_max.exp();
+
+  // step 5, obtain sum_exp_logits
+  auto sum_exp_logits_tensor =
+      softmax_2d_tensor.sum({1}, softmax_2d_tensor.dtype(), false);
+
+  auto sum_exp_logits =
+      reinterpret_cast<phi::DenseTensor*>(sum_exp_logits_tensor.impl().get());
+  phi::DeviceManager::CCLAllReduce(dev_ctx.GetPlace().GetDeviceType(),
+                                   sum_exp_logits->data<float>(),
+                                   sum_exp_logits->data<float>(),
+                                   sum_exp_logits->numel(),
+                                   sum_exp_logits->dtype(),
+                                   phi::ccl::CCLReduceOp::SUM,
+                                   comm->GetXcclComm(),
+                                   stream);
+
+  auto softmax_out = softmax_2d_tensor.divide(
+      paddle::experimental::reshape(sum_exp_logits_tensor, {N, 1}));
+  auto labels_1d_not_equal_ignore = labels_1d_tensor.not_equal(
+      paddle::experimental::full_like(labels_1d_tensor,
+                                      ignore_index,
+                                      labels_1d_tensor.dtype(),
+                                      labels_1d_tensor.place()));
+  auto loss_out =
+      (sum_exp_logits_tensor.log() - predicted_logits_tensor)
+          .multiply(paddle::experimental::cast(labels_1d_not_equal_ignore,
+                                               sum_exp_logits_tensor.dtype()));
+  softmax
+      ->ShareDataWith(
+          *reinterpret_cast<phi::DenseTensor*>(softmax_out.impl().get()))
+      .Resize(softmax_dims);
+  loss->ShareDataWith(
+          *reinterpret_cast<phi::DenseTensor*>(loss_out.impl().get()))
+      .Resize(loss_dims);
 }
 }  // namespace phi