|
193 | 193 | { |
194 | 194 | "cell_type": "code", |
195 | 195 | "execution_count": null, |
| 196 | + "metadata": {}, |
| 197 | + "outputs": [], |
| 198 | + "source": [] |
| 199 | + }, |
| 200 | + { |
| 201 | + "cell_type": "code", |
| 202 | + "execution_count": 7, |
196 | 203 | "metadata": { |
197 | 204 | "ExecuteTime": { |
198 | | - "end_time": "2021-09-01T17:37:22.054020Z", |
199 | | - "start_time": "2021-09-01T17:29:18.371Z" |
| 205 | + "end_time": "2021-09-01T17:49:13.840143Z", |
| 206 | + "start_time": "2021-09-01T17:45:39.855056Z" |
200 | 207 | } |
201 | 208 | }, |
202 | | - "outputs": [], |
| 209 | + "outputs": [ |
| 210 | + { |
| 211 | + "ename": "TypeError", |
| 212 | + "evalue": "'>=' not supported between instances of 'numpy.ndarray' and 'str'", |
| 213 | + "output_type": "error", |
| 214 | + "traceback": [ |
| 215 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 216 | + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", |
| 217 | + "\u001b[0;32m<ipython-input-7-8582c77c0a6d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n\u001b[1;32m 3\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mpred_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mpred_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mroc_auc_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_y\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroc_auc_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_y\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_test\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 218 | + "\u001b[0;32m~/anaconda2_local/envs/py37/lib/python3.7/site-packages/simtree/mobtree.py\u001b[0m in \u001b[0;36mpredict_proba\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 652\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 653\u001b[0;31m \u001b[0mproba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecision_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 654\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mproba\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproba\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 655\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", |
| 219 | + "\u001b[0;32m~/anaconda2_local/envs/py37/lib/python3.7/site-packages/simtree/mobtree.py\u001b[0m in \u001b[0;36mdecision_function\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0mnode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 504\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'is_leaf'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 505\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'feature'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'threshold'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 506\u001b[0m \u001b[0mnode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'left_child_id'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 507\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 220 | + "\u001b[0;31mTypeError\u001b[0m: '>=' not supported between instances of 'numpy.ndarray' and 'str'" |
| 221 | + ] |
| 222 | + } |
| 223 | + ], |
203 | 224 | "source": [ |
204 | 225 | "clf = GLMTreeClassifier(max_depth=3, min_samples_leaf=50, reg_lambda=np.logspace(-5, 5, 10).tolist(),\n", |
205 | 226 | " n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n", |
|
211 | 232 | }, |
212 | 233 | { |
213 | 234 | "cell_type": "code", |
214 | | - "execution_count": null, |
215 | | - "metadata": {}, |
| 235 | + "execution_count": 13, |
| 236 | + "metadata": { |
| 237 | + "ExecuteTime": { |
| 238 | + "end_time": "2021-09-01T17:53:47.559641Z", |
| 239 | + "start_time": "2021-09-01T17:53:47.530587Z" |
| 240 | + } |
| 241 | + }, |
216 | 242 | "outputs": [], |
| 243 | + "source": [ |
| 244 | + "leaf_idx = []\n", |
| 245 | + "for row in train_x.values:\n", |
| 246 | + " node = clf.tree[1]\n", |
| 247 | + " while not node['is_leaf']:\n", |
| 248 | + " if row[node['feature']] <= node['threshold']:\n", |
| 249 | + " node = clf.tree[node['left_child_id']]\n", |
| 250 | + " else:\n", |
| 251 | + " node = clf.tree[node['right_child_id']]\n", |
| 252 | + " leaf_idx.append(node['node_id'])" |
| 253 | + ] |
| 254 | + }, |
| 255 | + { |
| 256 | + "cell_type": "code", |
| 257 | + "execution_count": 14, |
| 258 | + "metadata": { |
| 259 | + "ExecuteTime": { |
| 260 | + "end_time": "2021-09-01T17:54:00.517384Z", |
| 261 | + "start_time": "2021-09-01T17:54:00.511025Z" |
| 262 | + } |
| 263 | + }, |
| 264 | + "outputs": [ |
| 265 | + { |
| 266 | + "data": { |
| 267 | + "text/plain": [ |
| 268 | + "array([[ 57.5644, 20.2196, 2.683 , ..., 11.8193, 26.1085, 217.544 ],\n", |
| 269 | + " [ 27.7996, 14.0561, 2.6839, ..., -9.2172, 38.023 , 97.7341],\n", |
| 270 | + " [ 48.4661, 22.7264, 2.953 , ..., 24.8371, 4.825 , 266.665 ],\n", |
| 271 | + " ...,\n", |
| 272 | + " [ 35.8286, 16.8952, 2.8802, ..., 11.3048, 0.472 , 234.868 ],\n", |
| 273 | + " [ 20.0986, 12.8671, 2.4057, ..., 7.875 , 21.675 , 212.098 ],\n", |
| 274 | + " [ 27.2726, 12.6129, 2.7288, ..., -9.9008, 3.789 , 185.431 ]])" |
| 275 | + ] |
| 276 | + }, |
| 277 | + "execution_count": 14, |
| 278 | + "metadata": {}, |
| 279 | + "output_type": "execute_result" |
| 280 | + } |
| 281 | + ], |
217 | 282 | "source": [] |
218 | 283 | } |
219 | 284 | ], |
|
0 commit comments