Skip to content

Commit c5ef402

Browse files
committed
address comments
Signed-off-by: Nitish Bharambe <[email protected]>
1 parent c934a5c commit c5ef402

File tree

1 file changed

+36
-53
lines changed

1 file changed

+36
-53
lines changed

docs/examples/arrow_example.ipynb

Lines changed: 36 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 15,
24+
"execution_count": 1,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -44,16 +44,6 @@
4444
"import numpy as np"
4545
]
4646
},
47-
{
48-
"cell_type": "code",
49-
"execution_count": 16,
50-
"metadata": {},
51-
"outputs": [],
52-
"source": [
53-
"# A constant showing error message\n",
54-
"ZERO_COPY_ERROR_MSG = \"Zero-copy conversion requested, but the data types do not match.\""
55-
]
56-
},
5747
{
5848
"cell_type": "markdown",
5949
"metadata": {},
@@ -93,7 +83,7 @@
9383
},
9484
{
9585
"cell_type": "code",
96-
"execution_count": 17,
86+
"execution_count": 2,
9787
"metadata": {},
9888
"outputs": [
9989
{
@@ -142,17 +132,17 @@
142132
},
143133
{
144134
"cell_type": "code",
145-
"execution_count": 18,
135+
"execution_count": 3,
146136
"metadata": {},
147137
"outputs": [
148138
{
149139
"name": "stdout",
150140
"output_type": "stream",
151141
"text": [
152-
"-------node combined asym scehma-------\n",
142+
"-------node asym scehma-------\n",
153143
"id: int32\n",
154144
"u_rated: double\n",
155-
"-------asym load combined asym scehma-------\n",
145+
"-------asym load scehma-------\n",
156146
"id: int32\n",
157147
"node: int32\n",
158148
"status: int8\n",
@@ -179,9 +169,9 @@
179169
" return pa.schema(schemas)\n",
180170
"\n",
181171
"\n",
182-
"print(\"-------node combined asym scehma-------\")\n",
172+
"print(\"-------node asym scehma-------\")\n",
183173
"print(pgm_schema(DatasetType.input, ComponentType.node))\n",
184-
"print(\"-------asym load combined asym scehma-------\")\n",
174+
"print(\"-------asym load scehma-------\")\n",
185175
"print(pgm_schema(DatasetType.input, ComponentType.asym_load))"
186176
]
187177
},
@@ -198,7 +188,7 @@
198188
},
199189
{
200190
"cell_type": "code",
201-
"execution_count": 19,
191+
"execution_count": 4,
202192
"metadata": {},
203193
"outputs": [
204194
{
@@ -212,7 +202,7 @@
212202
"u_rated: [10500,10500,10500]"
213203
]
214204
},
215-
"execution_count": 19,
205+
"execution_count": 4,
216206
"metadata": {},
217207
"output_type": "execute_result"
218208
}
@@ -272,13 +262,14 @@
272262
"A similar approach can be adopted by the user to convert to row-based data.\n",
273263
"\n",
274264
"```{note}\n",
275-
"The option of `zero_copy_only` in the function below is added in this demo to verify no copies are made. Its usage is not mandatory to do zero copy conversion.\n",
265+
"The `zero_copy_only` option and the dtype assertion in the function below are added in this demo to verify that no copies are made.\n",
266+
"Neither is required to perform a zero-copy conversion.\n",
276267
"```"
277268
]
278269
},
279270
{
280271
"cell_type": "code",
281-
"execution_count": 20,
272+
"execution_count": 5,
282273
"metadata": {},
283274
"outputs": [
284275
{
@@ -287,27 +278,24 @@
287278
"{'id': array([1, 2, 3]), 'u_rated': array([10500., 10500., 10500.])}"
288279
]
289280
},
290-
"execution_count": 20,
281+
"execution_count": 5,
291282
"metadata": {},
292283
"output_type": "execute_result"
293284
}
294285
],
295286
"source": [
296-
"def arrow_to_numpy(\n",
297-
" data: pa.RecordBatch, dataset_type: DatasetType, component_type: ComponentType, zero_copy_only: bool = False\n",
298-
") -> np.ndarray:\n",
287+
"def arrow_to_numpy(data: pa.RecordBatch, dataset_type: DatasetType, component_type: ComponentType) -> np.ndarray:\n",
299288
" \"\"\"Convert Arrow data to NumPy data.\"\"\"\n",
300289
" result = {}\n",
301290
" result_dtype = power_grid_meta_data[dataset_type][component_type].dtype\n",
302291
" for name, column in zip(data.column_names, data.columns):\n",
303-
" column_data = column.to_numpy(zero_copy_only=zero_copy_only)\n",
304-
" if zero_copy_only and column_data.dtype != result_dtype[name]:\n",
305-
" raise ValueError(ZERO_COPY_ERROR_MSG)\n",
292+
" column_data = column.to_numpy(zero_copy_only=True)\n",
293+
" assert column_data.dtype == result_dtype[name]\n",
306294
" result[name] = column_data.astype(dtype=result_dtype[name], copy=False)\n",
307295
" return result\n",
308296
"\n",
309297
"\n",
310-
"node_input = arrow_to_numpy(nodes, DatasetType.input, ComponentType.node, zero_copy_only=True)\n",
298+
"node_input = arrow_to_numpy(nodes, DatasetType.input, ComponentType.node)\n",
311299
"line_input = arrow_to_numpy(lines, DatasetType.input, ComponentType.line)\n",
312300
"source_input = arrow_to_numpy(sources, DatasetType.input, ComponentType.source)\n",
313301
"sym_load_input = arrow_to_numpy(sym_loads, DatasetType.input, ComponentType.sym_load)\n",
@@ -324,7 +312,7 @@
324312
},
325313
{
326314
"cell_type": "code",
327-
"execution_count": 21,
315+
"execution_count": 6,
328316
"metadata": {},
329317
"outputs": [],
330318
"source": [
@@ -338,7 +326,7 @@
338326
},
339327
{
340328
"cell_type": "code",
341-
"execution_count": 22,
329+
"execution_count": 7,
342330
"metadata": {},
343331
"outputs": [],
344332
"source": [
@@ -361,7 +349,7 @@
361349
},
362350
{
363351
"cell_type": "code",
364-
"execution_count": 23,
352+
"execution_count": null,
365353
"metadata": {},
366354
"outputs": [
367355
{
@@ -473,7 +461,7 @@
473461
},
474462
{
475463
"cell_type": "code",
476-
"execution_count": 24,
464+
"execution_count": 9,
477465
"metadata": {},
478466
"outputs": [
479467
{
@@ -497,7 +485,7 @@
497485
"q: [-3299418.661306348,-0.5000000701801947,-1.4999998507078594]"
498486
]
499487
},
500-
"execution_count": 24,
488+
"execution_count": 9,
501489
"metadata": {},
502490
"output_type": "execute_result"
503491
}
@@ -536,7 +524,7 @@
536524
},
537525
{
538526
"cell_type": "code",
539-
"execution_count": 25,
527+
"execution_count": 10,
540528
"metadata": {},
541529
"outputs": [
542530
{
@@ -560,7 +548,7 @@
560548
"q_specified: [[0.5,1500,0.1],[1.5,2.5,1500]]"
561549
]
562550
},
563-
"execution_count": 25,
551+
"execution_count": 10,
564552
"metadata": {},
565553
"output_type": "execute_result"
566554
}
@@ -584,7 +572,7 @@
584572
},
585573
{
586574
"cell_type": "code",
587-
"execution_count": 26,
575+
"execution_count": 11,
588576
"metadata": {},
589577
"outputs": [
590578
{
@@ -600,15 +588,13 @@
600588
" [1.5e+00, 2.5e+00, 1.5e+03]])}"
601589
]
602590
},
603-
"execution_count": 26,
591+
"execution_count": 11,
604592
"metadata": {},
605593
"output_type": "execute_result"
606594
}
607595
],
608596
"source": [
609-
"def arrow_to_numpy_asym(\n",
610-
" data: pa.RecordBatch, dataset_type: DatasetType, component_type: ComponentType, zero_copy_only: bool = False\n",
611-
") -> np.ndarray:\n",
597+
"def arrow_to_numpy_asym(data: pa.RecordBatch, dataset_type: DatasetType, component_type: ComponentType) -> np.ndarray:\n",
612598
" \"\"\"Convert asymmetric Arrow data to NumPy data.\n",
613599
"\n",
614600
" This function is similar to the arrow_to_numpy function, but also supports asymmetric data.\"\"\"\n",
@@ -621,17 +607,15 @@
621607
" dtype = result_dtype[name]\n",
622608
"\n",
623609
" if len(dtype.shape) == 0:\n",
624-
" column_data = data.column(name).to_numpy(zero_copy_only=zero_copy_only)\n",
610+
" column_data = data.column(name).to_numpy(zero_copy_only=True)\n",
625611
" else:\n",
626-
" column_data = data.column(name).flatten().to_numpy(zero_copy_only=zero_copy_only).reshape(-1, 3)\n",
627-
"\n",
628-
" if zero_copy_only and column_data.dtype.base != dtype.base:\n",
629-
" raise ValueError(ZERO_COPY_ERROR_MSG)\n",
612+
" column_data = data.column(name).flatten().to_numpy(zero_copy_only=True).reshape(-1, 3)\n",
613+
" assert column_data.dtype.base == dtype.base\n",
630614
" result[name] = column_data.astype(dtype=dtype.base, copy=False)\n",
631615
" return result\n",
632616
"\n",
633617
"\n",
634-
"asym_load_input = arrow_to_numpy_asym(asym_loads, DatasetType.input, ComponentType.asym_load, zero_copy_only=True)\n",
618+
"asym_load_input = arrow_to_numpy_asym(asym_loads, DatasetType.input, ComponentType.asym_load)\n",
635619
"\n",
636620
"asym_load_input"
637621
]
@@ -645,7 +629,7 @@
645629
},
646630
{
647631
"cell_type": "code",
648-
"execution_count": 27,
632+
"execution_count": 12,
649633
"metadata": {},
650634
"outputs": [
651635
{
@@ -704,9 +688,8 @@
704688
"2 -0.004338 -2.098733 2.090057"
705689
]
706690
},
707-
"execution_count": 27,
708691
"metadata": {},
709-
"output_type": "execute_result"
692+
"output_type": "display_data"
710693
}
711694
],
712695
"source": [
@@ -728,7 +711,7 @@
728711
")\n",
729712
"\n",
730713
"# use pandas to display the results, but beware the data types\n",
731-
"pd.DataFrame(asym_result[ComponentType.node][\"u_angle\"])"
714+
"display(pd.DataFrame(asym_result[ComponentType.node][\"u_angle\"]))"
732715
]
733716
},
734717
{
@@ -740,7 +723,7 @@
740723
},
741724
{
742725
"cell_type": "code",
743-
"execution_count": 28,
726+
"execution_count": 13,
744727
"metadata": {},
745728
"outputs": [
746729
{
@@ -769,7 +752,7 @@
769752
"q: [[-1099806.4185888197,-1098301.0302391076,-1098302.79423175],[-0.499999998516201,-1499.9999999095232,-0.10000001915949493],[-1.5000000216889147,-2.50000006806065,-1500.0000000385737]]"
770753
]
771754
},
772-
"execution_count": 28,
755+
"execution_count": 13,
773756
"metadata": {},
774757
"output_type": "execute_result"
775758
}

0 commit comments

Comments
 (0)