Skip to content

Commit 8fa83a5

Browse files
committed
1 parent 4b58671 commit 8fa83a5

25 files changed

+769
-707
lines changed

.doctrees/environment.pickle

0 Bytes
Binary file not shown.

.doctrees/guides/dev_guide.doctree

-7.24 KB
Binary file not shown.
-6 Bytes
Binary file not shown.

.doctrees/nbsphinx/guides/dev_guide.ipynb

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@
137137
"id": "ef85e105",
138138
"metadata": {
139139
"execution": {
140-
"iopub.execute_input": "2025-12-22T06:39:13.162430Z",
141-
"iopub.status.busy": "2025-12-22T06:39:13.162285Z",
142-
"iopub.status.idle": "2025-12-22T06:39:14.345651Z",
143-
"shell.execute_reply": "2025-12-22T06:39:14.344809Z"
140+
"iopub.execute_input": "2025-12-22T06:53:14.508507Z",
141+
"iopub.status.busy": "2025-12-22T06:53:14.508354Z",
142+
"iopub.status.idle": "2025-12-22T06:53:15.786608Z",
143+
"shell.execute_reply": "2025-12-22T06:53:15.785666Z"
144144
}
145145
},
146146
"outputs": [],
@@ -219,10 +219,10 @@
219219
"id": "fa585400",
220220
"metadata": {
221221
"execution": {
222-
"iopub.execute_input": "2025-12-22T06:39:14.347524Z",
223-
"iopub.status.busy": "2025-12-22T06:39:14.347292Z",
224-
"iopub.status.idle": "2025-12-22T06:39:14.354413Z",
225-
"shell.execute_reply": "2025-12-22T06:39:14.353607Z"
222+
"iopub.execute_input": "2025-12-22T06:53:15.788481Z",
223+
"iopub.status.busy": "2025-12-22T06:53:15.788221Z",
224+
"iopub.status.idle": "2025-12-22T06:53:15.796119Z",
225+
"shell.execute_reply": "2025-12-22T06:53:15.795231Z"
226226
}
227227
},
228228
"outputs": [
@@ -303,10 +303,10 @@
303303
"id": "98b465d4",
304304
"metadata": {
305305
"execution": {
306-
"iopub.execute_input": "2025-12-22T06:39:14.356021Z",
307-
"iopub.status.busy": "2025-12-22T06:39:14.355844Z",
308-
"iopub.status.idle": "2025-12-22T06:39:14.368147Z",
309-
"shell.execute_reply": "2025-12-22T06:39:14.367369Z"
306+
"iopub.execute_input": "2025-12-22T06:53:15.797968Z",
307+
"iopub.status.busy": "2025-12-22T06:53:15.797751Z",
308+
"iopub.status.idle": "2025-12-22T06:53:15.810927Z",
309+
"shell.execute_reply": "2025-12-22T06:53:15.809984Z"
310310
}
311311
},
312312
"outputs": [],
@@ -441,10 +441,10 @@
441441
"id": "6ccb1cfa",
442442
"metadata": {
443443
"execution": {
444-
"iopub.execute_input": "2025-12-22T06:39:14.369660Z",
445-
"iopub.status.busy": "2025-12-22T06:39:14.369495Z",
446-
"iopub.status.idle": "2025-12-22T06:39:14.376087Z",
447-
"shell.execute_reply": "2025-12-22T06:39:14.375447Z"
444+
"iopub.execute_input": "2025-12-22T06:53:15.812596Z",
445+
"iopub.status.busy": "2025-12-22T06:53:15.812412Z",
446+
"iopub.status.idle": "2025-12-22T06:53:15.819625Z",
447+
"shell.execute_reply": "2025-12-22T06:53:15.818916Z"
448448
}
449449
},
450450
"outputs": [
@@ -536,10 +536,10 @@
536536
"id": "e16a0f66",
537537
"metadata": {
538538
"execution": {
539-
"iopub.execute_input": "2025-12-22T06:39:14.377908Z",
540-
"iopub.status.busy": "2025-12-22T06:39:14.377753Z",
541-
"iopub.status.idle": "2025-12-22T06:39:14.386118Z",
542-
"shell.execute_reply": "2025-12-22T06:39:14.385472Z"
539+
"iopub.execute_input": "2025-12-22T06:53:15.821459Z",
540+
"iopub.status.busy": "2025-12-22T06:53:15.821256Z",
541+
"iopub.status.idle": "2025-12-22T06:53:15.830300Z",
542+
"shell.execute_reply": "2025-12-22T06:53:15.829509Z"
543543
}
544544
},
545545
"outputs": [],
@@ -633,10 +633,10 @@
633633
"id": "c034fcc9",
634634
"metadata": {
635635
"execution": {
636-
"iopub.execute_input": "2025-12-22T06:39:14.387898Z",
637-
"iopub.status.busy": "2025-12-22T06:39:14.387739Z",
638-
"iopub.status.idle": "2025-12-22T06:39:14.397542Z",
639-
"shell.execute_reply": "2025-12-22T06:39:14.396740Z"
636+
"iopub.execute_input": "2025-12-22T06:53:15.832058Z",
637+
"iopub.status.busy": "2025-12-22T06:53:15.831676Z",
638+
"iopub.status.idle": "2025-12-22T06:53:15.841992Z",
639+
"shell.execute_reply": "2025-12-22T06:53:15.841212Z"
640640
},
641641
"lines_to_next_cell": 2
642642
},
@@ -732,10 +732,10 @@
732732
"id": "849042f7",
733733
"metadata": {
734734
"execution": {
735-
"iopub.execute_input": "2025-12-22T06:39:14.399086Z",
736-
"iopub.status.busy": "2025-12-22T06:39:14.398933Z",
737-
"iopub.status.idle": "2025-12-22T06:39:14.404434Z",
738-
"shell.execute_reply": "2025-12-22T06:39:14.403648Z"
735+
"iopub.execute_input": "2025-12-22T06:53:15.843705Z",
736+
"iopub.status.busy": "2025-12-22T06:53:15.843548Z",
737+
"iopub.status.idle": "2025-12-22T06:53:15.850364Z",
738+
"shell.execute_reply": "2025-12-22T06:53:15.849620Z"
739739
}
740740
},
741741
"outputs": [],
@@ -758,6 +758,9 @@
758758
" # Create boolean masks for each split\n",
759759
" self.split_masks = [assignments == i for i in range(num_splits)]\n",
760760
"\n",
761+
" # Store assignments for merge_split_indices\n",
762+
" self.register_buffer(\"_assignments\", assignments)\n",
763+
"\n",
761764
" @property\n",
762765
" def feature_to_scope(self):\n",
763766
" scopes = self.inputs.feature_to_scope\n",
@@ -769,7 +772,23 @@
769772
" @cached\n",
770773
" def log_likelihood(self, data, cache=None):\n",
771774
" lls = self.inputs.log_likelihood(data, cache=cache)\n",
772-
" return [lls[:, mask, ...] for mask in self.split_masks]"
775+
" return [lls[:, mask, ...] for mask in self.split_masks]\n",
776+
"\n",
777+
"\n",
778+
" def merge_split_indices(self, *split_indices: Tensor) -> Tensor:\n",
779+
" batch_size = split_indices[0].shape[0]\n",
780+
" num_features = self.inputs.out_shape.features\n",
781+
" # Create output tensor\n",
782+
" result = torch.zeros(batch_size, num_features, dtype=split_indices[0].dtype, device=split_indices[0].device)\n",
783+
" # Track position within each split\n",
784+
" split_positions = [0] * self.num_splits\n",
785+
" # Scatter indices back to original positions\n",
786+
" for feature_idx in range(num_features):\n",
787+
" split_idx = self._assignments[feature_idx].item()\n",
788+
" pos = split_positions[split_idx]\n",
789+
" result[:, feature_idx] = split_indices[split_idx][:, pos]\n",
790+
" split_positions[split_idx] += 1\n",
791+
" return result"
773792
]
774793
},
775794
{
@@ -778,22 +797,19 @@
778797
"id": "f5b66811",
779798
"metadata": {
780799
"execution": {
781-
"iopub.execute_input": "2025-12-22T06:39:14.405857Z",
782-
"iopub.status.busy": "2025-12-22T06:39:14.405676Z",
783-
"iopub.status.idle": "2025-12-22T06:39:14.572360Z",
784-
"shell.execute_reply": "2025-12-22T06:39:14.571565Z"
800+
"iopub.execute_input": "2025-12-22T06:53:15.852010Z",
801+
"iopub.status.busy": "2025-12-22T06:53:15.851829Z",
802+
"iopub.status.idle": "2025-12-22T06:53:15.856663Z",
803+
"shell.execute_reply": "2025-12-22T06:53:15.855999Z"
785804
}
786805
},
787806
"outputs": [
788807
{
789-
"ename": "TypeError",
790-
"evalue": "Can't instantiate abstract class RandomSplit with abstract method merge_split_indices",
791-
"output_type": "error",
792-
"traceback": [
793-
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
794-
"\u001b[31mTypeError\u001b[39m Traceback (most recent call last)",
795-
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Test RandomSplit\u001b[39;00m\n\u001b[32m 2\u001b[39m leaf = Normal(scope=Scope(\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[32m6\u001b[39m))), out_channels=\u001b[32m2\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m split = \u001b[43mRandomSplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mleaf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_splits\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m123\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 5\u001b[39m data = torch.randn(\u001b[32m5\u001b[39m, \u001b[32m6\u001b[39m)\n\u001b[32m 6\u001b[39m lls = split.log_likelihood(data)\n",
796-
"\u001b[31mTypeError\u001b[39m: Can't instantiate abstract class RandomSplit with abstract method merge_split_indices"
808+
"name": "stdout",
809+
"output_type": "stream",
810+
"text": [
811+
"Split 0 shape: torch.Size([5, 5, 2, 1])\n",
812+
"Split 1 shape: torch.Size([5, 1, 2, 1])\n"
797813
]
798814
}
799815
],
@@ -831,10 +847,10 @@
831847
"id": "81fc02a3",
832848
"metadata": {
833849
"execution": {
834-
"iopub.execute_input": "2025-12-22T06:39:14.574005Z",
835-
"iopub.status.busy": "2025-12-22T06:39:14.573846Z",
836-
"iopub.status.idle": "2025-12-22T06:39:14.580174Z",
837-
"shell.execute_reply": "2025-12-22T06:39:14.579384Z"
850+
"iopub.execute_input": "2025-12-22T06:53:15.858475Z",
851+
"iopub.status.busy": "2025-12-22T06:53:15.858322Z",
852+
"iopub.status.idle": "2025-12-22T06:53:15.864184Z",
853+
"shell.execute_reply": "2025-12-22T06:53:15.863391Z"
838854
},
839855
"lines_to_next_cell": 2
840856
},

.doctrees/nbsphinx/guides/user_guide.ipynb

Lines changed: 233 additions & 233 deletions
Large diffs are not rendered by default.
-6.77 KB
Loading
-2.01 KB
Loading
-4.19 KB
Loading
-4.01 KB
Loading
-3.36 KB
Loading

0 commit comments

Comments
 (0)