@@ -32,13 +32,8 @@ __global__ void LoadGDHessFvalue(const real *pInsGD, const real *pInsHess, int n
3232__global__ void FirstFeaGain (const unsigned int *pEachFeaStartPosEachNode, int numFeaStartPos, real *pGainOnEachFeaValue, uint numFeaValue);
3333
3434// helper functions
35- template <class T >
36- __device__ bool NeedUpdate (T &RChildHess, T &LChildHess)
37- {
38- if (LChildHess >= DeviceSplitter::min_child_weight && RChildHess >= DeviceSplitter::min_child_weight)
39- return true ;
40- return false ;
41- }
35+ __device__ bool NeedCompGain (double RChildHess, double LChildHess);
36+ __device__ double ComputeGain (double tempGD, double tempHess, real lambda, double rChildGD, double rChildHess, double parentGD, double parentHess);
4237
4338template <class T >
4439__global__ void ComputeGainDense (const nodeStat *pSNodeStat, const int *pid2SNPos, real lambda,
@@ -59,71 +54,89 @@ __global__ void ComputeGainDense(const nodeStat *pSNodeStat, const int *pid2SNPo
5954 int snPos = pid2SNPos[pid];
6055 ECHECKER (snPos);
6156
62- if (gTid == 0 )
63- {
64- // assign gain to 0 to the first feature value
65- pGainOnEachFeaValue[gTid ] = 0 ;
66- return ;
67- }
68-
69- // if the previous fea value is the same as the current fea value, gain is 0 for the current fea value.
70- real preFvalue = pDenseFeaValue[gTid - 1 ], curFvalue = pDenseFeaValue[gTid ];
71- if (preFvalue - curFvalue <= rt_2eps && preFvalue - curFvalue >= -rt_2eps)// ############## backwards is not considered!
72- {// avoid same feature value different gain issue
73- pGainOnEachFeaValue[gTid ] = 0 ;
74- return ;
75- }
76-
77- int exclusiveSumPos = gTid - 1 ;// following xgboost using exclusive sum on gd and hess
78-
79- // forward consideration (fvalues are sorted descendingly)
80- double rChildGD = pGDPrefixSumOnEachFeaValue[exclusiveSumPos];
81- double rChildHess = pHessPrefixSumOnEachFeaValue[exclusiveSumPos];
82- double parentGD = pSNodeStat[snPos].sum_gd ;
83- double parentHess = pSNodeStat[snPos].sum_hess ;
84- double tempGD = parentGD - rChildGD;
85- double tempHess = parentHess - rChildHess;
86- bool needUpdate = NeedUpdate (rChildHess, tempHess);
87- if (needUpdate == true )// need to compute the gain
88- {
89- ECHECKER (tempHess > 0 );
90- ECHECKER (parentHess > 0 );
91- double tempGain = (tempGD * tempGD)/(tempHess + lambda) +
92- (rChildGD * rChildGD)/(rChildHess + lambda) -
93- (parentGD * parentGD)/(parentHess + lambda);
94- pGainOnEachFeaValue[gTid ] = tempGain;
95- }
96- else {
97- // assign gain to 0
98- pGainOnEachFeaValue[gTid ] = 0 ;
99- }
100-
101- // backward consideration
10257 int segLen = pEachFeaLenEachNode[segId];
10358 uint segStartPos = pEachFeaStartEachNode[segId];
104- uint lastFvaluePos = segStartPos + segLen - 1 ;
105- double totalMissingGD = parentGD - pGDPrefixSumOnEachFeaValue[lastFvaluePos];
106- double totalMissingHess = parentHess - pHessPrefixSumOnEachFeaValue[lastFvaluePos];
107- if (totalMissingHess < 1 )// there is no instance with missing values
108- return ;
109- // missing values to the right child
110- rChildGD += totalMissingGD;
111- rChildHess += totalMissingHess;
112- tempGD = parentGD - rChildGD;
113- tempHess = parentHess - rChildHess;
114- needUpdate = NeedUpdate (rChildHess, tempHess);
115- if (needUpdate == true ){
116- ECHECKER (tempHess > 0 );
117- ECHECKER (parentHess > 0 );
118- double tempGain = (tempGD * tempGD)/(tempHess + lambda) +
119- (rChildGD * rChildGD)/(rChildHess + lambda) -
120- (parentGD * parentGD)/(parentHess + lambda);
121-
122- if (tempGain > 0 && tempGain - pGainOnEachFeaValue[gTid ] > 0.1 ){
123- pGainOnEachFeaValue[gTid ] = tempGain;
59+ uint lastFvaluePos = segStartPos + segLen - 1 ;
60+ double parentGD = pSNodeStat[snPos].sum_gd ;
61+ double parentHess = pSNodeStat[snPos].sum_hess ;
62+ double totalMissingGD = parentGD - pGDPrefixSumOnEachFeaValue[lastFvaluePos];
63+ double totalMissingHess = parentHess - pHessPrefixSumOnEachFeaValue[lastFvaluePos];
64+
65+ if (gTid == segStartPos){// include all sum statistics; store the gain at the first pos of each segment
66+ // check default to left
67+ double totalGD = pGDPrefixSumOnEachFeaValue[lastFvaluePos];
68+ double totalHess = pHessPrefixSumOnEachFeaValue[lastFvaluePos];
69+ bool needUpdate = NeedCompGain (totalHess, totalMissingHess);
70+ real all2Left = 0 , all2Right = -1.0 ;
71+ if (needUpdate == true ){
72+ CONCHECKER (totalHess > 0 );
73+ CONCHECKER (totalMissingHess > 0 );
74+ all2Left = ComputeGain (totalMissingGD, totalMissingHess, lambda, totalGD, totalHess, parentGD, parentHess);
75+ all2Right = ComputeGain (totalGD, totalHess, lambda, totalMissingGD, totalMissingHess, parentGD, parentHess);
76+ }
77+
78+ // check default to right
79+ if (all2Left < all2Right){
80+ pGainOnEachFeaValue[gTid ] = all2Right;
12481 pDefault2Right[gTid ] = true ;
12582 }
126- }
83+ else
84+ pGainOnEachFeaValue[gTid ] = all2Left;
85+ // if(7773118 == gTid)
86+ // printf("gain=%f, totalHess=%f, totalMissHess=%f, last+1=%f, last-1=%f, last=%f, last-2=%f, last=%u\n",
87+ // pGainOnEachFeaValue[gTid], totalHess, totalMissingHess, pHessPrefixSumOnEachFeaValue[lastFvaluePos+1],
88+ // pHessPrefixSumOnEachFeaValue[lastFvaluePos],
89+ // pHessPrefixSumOnEachFeaValue[lastFvaluePos-1], pHessPrefixSumOnEachFeaValue[lastFvaluePos-2], lastFvaluePos);
90+ }
91+ else {
92+ // if the previous fea value is the same as the current fea value, gain is 0 for the current fea value.
93+ real preFvalue = pDenseFeaValue[gTid - 1 ], curFvalue = pDenseFeaValue[gTid ];
94+ if (preFvalue - curFvalue <= rt_2eps && preFvalue - curFvalue >= -rt_2eps)// ############## backwards is not considered!
95+ {// avoid same feature value different gain issue
96+ pGainOnEachFeaValue[gTid ] = 0 ;
97+ return ;
98+ }
99+
100+ int exclusiveSumPos = gTid - 1 ;// following xgboost using exclusive sum on gd and hess
101+
102+ // forward consideration (fvalues are sorted descendingly)
103+ double rChildGD = pGDPrefixSumOnEachFeaValue[exclusiveSumPos];
104+ double rChildHess = pHessPrefixSumOnEachFeaValue[exclusiveSumPos];
105+ double tempGD = parentGD - rChildGD;
106+ double tempHess = parentHess - rChildHess;
107+ bool needUpdate = NeedCompGain (rChildHess, tempHess);
108+ if (needUpdate == true )// need to compute the gain
109+ {
110+ ECHECKER (tempHess > 0 );
111+ ECHECKER (parentHess > 0 );
112+ double tempGain = ComputeGain (tempGD, tempHess, lambda, rChildGD, rChildHess, parentGD, parentHess);
113+ pGainOnEachFeaValue[gTid ] = tempGain;
114+ }
115+ else {
116+ // assign gain to 0
117+ pGainOnEachFeaValue[gTid ] = 0 ;
118+ }
119+
120+ // backward consideration
121+ if (totalMissingHess < 1 )// there is no instance with missing values
122+ return ;
123+ // missing values to the right child
124+ rChildGD += totalMissingGD;
125+ rChildHess += totalMissingHess;
126+ tempGD = parentGD - rChildGD;
127+ tempHess = parentHess - rChildHess;
128+ needUpdate = NeedCompGain (rChildHess, tempHess);
129+ if (needUpdate == true ){
130+ ECHECKER (tempHess > 0 );
131+ ECHECKER (parentHess > 0 );
132+ double tempGain = ComputeGain (tempGD, tempHess, lambda, rChildGD, rChildHess, parentGD, parentHess);
133+
134+ if (tempGain > 0 && tempGain - pGainOnEachFeaValue[gTid ] > 0.1 ){
135+ pGainOnEachFeaValue[gTid ] = tempGain;
136+ pDefault2Right[gTid ] = true ;
137+ }
138+ }
139+ }// end of forward and backward consideration
127140}
128141
129142/* *
@@ -150,46 +163,70 @@ __global__ void FindSplitInfo(const uint *pEachFeaStartPosEachNode, const T *pEa
150163 uint segId = pnKey[key];
151164 uint bestFeaId = segId % numFea;
152165 CONCHECKER (bestFeaId < numFea);
153-
154-
155166 pBestSplitPoint[snPos].m_nFeatureId = bestFeaId;
156- pBestSplitPoint[snPos].m_fSplitValue = 0 .5f * (pDenseFeaValue[key] + pDenseFeaValue[key - 1 ]);
157167 pBestSplitPoint[snPos].m_bDefault2Right = false ;
158168
159- // child node stat
160- int idxPreSum = key - 1 ;// follow xgboost using exclusive
161- if (pDefault2Right[key] == false ){
162- pLChildStat[snPos].sum_gd = snNodeStat[snPos].sum_gd - pPrefixSumGD[idxPreSum];
163- pLChildStat[snPos].sum_hess = snNodeStat[snPos].sum_hess - pPrefixSumHess[idxPreSum];
164- pRChildStat[snPos].sum_gd = pPrefixSumGD[idxPreSum];
165- pRChildStat[snPos].sum_hess = pPrefixSumHess[idxPreSum];
169+ // handle all to left/right case
170+ uint segStartPos = pEachFeaStartPosEachNode[segId];
171+ T segLen = pEachFeaLenEachNode[segId];
172+ uint lastFvaluePos = segStartPos + segLen - 1 ;
173+ if (key == 0 || (key > 0 && pnKey[key] != pnKey[key - 1 ])){// first element of the feature
174+ const real gap = fabs (pDenseFeaValue[key]) + DeviceSplitter::rt_eps;
175+ // printf("############## %u v.s. %u; %u\n", bestFeaId, bestFeaId, pPrefixSumHess[key]);
176+ // printf("missing %f all to one node: fid=%u, pnKey[%u]=%u != pnKey[%u]=%u, segLen=%u, parentHess=%f, startPos=%u..........................\n",
177+ // pPrefixSumHess[key], bestFeaId, key, pnKey[key], key-1, pnKey[key-1], pEachFeaLenEachNode[segId],
178+ // snNodeStat[snPos].sum_hess, pEachFeaStartPosEachNode[segId]);
179+ if (pDefault2Right[key] == true ){// all non-missing to left
180+ pBestSplitPoint[snPos].m_bDefault2Right = true ;
181+ pBestSplitPoint[snPos].m_fSplitValue = pDenseFeaValue[lastFvaluePos] + gap;
182+ pLChildStat[snPos].sum_gd = pPrefixSumGD[lastFvaluePos];
183+ pLChildStat[snPos].sum_hess = pPrefixSumHess[lastFvaluePos];
184+ pRChildStat[snPos].sum_gd = snNodeStat[snPos].sum_gd - pPrefixSumGD[lastFvaluePos];
185+ pRChildStat[snPos].sum_hess = snNodeStat[snPos].sum_hess - pPrefixSumHess[lastFvaluePos];
186+ }
187+ else {// all non-missing to right
188+ pBestSplitPoint[snPos].m_fSplitValue = pDenseFeaValue[lastFvaluePos] - gap;
189+ pLChildStat[snPos].sum_gd = snNodeStat[snPos].sum_gd - pPrefixSumGD[lastFvaluePos];
190+ pLChildStat[snPos].sum_hess = snNodeStat[snPos].sum_hess - pPrefixSumHess[lastFvaluePos];
191+ pRChildStat[snPos].sum_gd = pPrefixSumGD[lastFvaluePos];
192+ pRChildStat[snPos].sum_hess = pPrefixSumHess[lastFvaluePos];
193+ }
166194 }
167- else {
168- pBestSplitPoint[snPos].m_bDefault2Right = true ;
169-
170- real parentGD = snNodeStat[snPos].sum_gd ;
171- real parentHess = snNodeStat[snPos].sum_hess ;
172-
173- uint segStartPos = pEachFeaStartPosEachNode[segId];
174- T segLen = pEachFeaLenEachNode[segId];
175- uint lastFvaluePos = segStartPos + segLen - 1 ;
176- real totalMissingGD = parentGD - pPrefixSumGD[lastFvaluePos];
177- real totalMissingHess = parentHess - pPrefixSumHess[lastFvaluePos];
178-
179- double rChildGD = totalMissingGD + pPrefixSumGD[idxPreSum];
180- real rChildHess = totalMissingHess + pPrefixSumHess[idxPreSum];
181- ECHECKER (rChildHess);
182- real lChildGD = parentGD - rChildGD;
183- real lChildHess = parentHess - rChildHess;
184- ECHECKER (lChildHess);
185-
186- pRChildStat[snPos].sum_gd = rChildGD;
187- pRChildStat[snPos].sum_hess = rChildHess;
188- pLChildStat[snPos].sum_gd = lChildGD;
189- pLChildStat[snPos].sum_hess = lChildHess;
195+ else {// non-first element of the feature
196+ pBestSplitPoint[snPos].m_fSplitValue = 0 .5f * (pDenseFeaValue[key] + pDenseFeaValue[key - 1 ]);
197+
198+ // child node stat
199+ int idxPreSum = key - 1 ;// follow xgboost using exclusive
200+ if (pDefault2Right[key] == false ){
201+ pLChildStat[snPos].sum_gd = snNodeStat[snPos].sum_gd - pPrefixSumGD[idxPreSum];
202+ pLChildStat[snPos].sum_hess = snNodeStat[snPos].sum_hess - pPrefixSumHess[idxPreSum];
203+ pRChildStat[snPos].sum_gd = pPrefixSumGD[idxPreSum];
204+ pRChildStat[snPos].sum_hess = pPrefixSumHess[idxPreSum];
205+ }
206+ else {
207+ pBestSplitPoint[snPos].m_bDefault2Right = true ;
208+
209+ real parentGD = snNodeStat[snPos].sum_gd ;
210+ real parentHess = snNodeStat[snPos].sum_hess ;
211+
212+ real totalMissingGD = parentGD - pPrefixSumGD[lastFvaluePos];
213+ real totalMissingHess = parentHess - pPrefixSumHess[lastFvaluePos];
214+
215+ double rChildGD = totalMissingGD + pPrefixSumGD[idxPreSum];
216+ real rChildHess = totalMissingHess + pPrefixSumHess[idxPreSum];
217+ ECHECKER (rChildHess);
218+ real lChildGD = parentGD - rChildGD;
219+ real lChildHess = parentHess - rChildHess;
220+ ECHECKER (lChildHess);
221+
222+ pRChildStat[snPos].sum_gd = rChildGD;
223+ pRChildStat[snPos].sum_hess = rChildHess;
224+ pLChildStat[snPos].sum_gd = lChildGD;
225+ pLChildStat[snPos].sum_hess = lChildHess;
226+ }
227+ ECHECKER (pLChildStat[snPos].sum_hess );
228+ ECHECKER (pRChildStat[snPos].sum_hess );
190229 }
191- ECHECKER (pLChildStat[snPos].sum_hess );
192- ECHECKER (pRChildStat[snPos].sum_hess );
193230
194231// printf("split: f=%d, value=%f, gain=%f, gd=%f v.s. %f, hess=%f v.s. %f, buffId=%d, key=%d, pid=%d, df2Left=%d\n", bestFeaId, pBestSplitPoint[snPos].m_fSplitValue,
195232// pBestSplitPoint[snPos].m_fGain, pLChildStat[snPos].sum_gd, pRChildStat[snPos].sum_gd, pLChildStat[snPos].sum_hess,
0 commit comments