|
137 | 137 | "id": "ef85e105", |
138 | 138 | "metadata": { |
139 | 139 | "execution": { |
140 | | - "iopub.execute_input": "2025-12-22T06:39:13.162430Z", |
141 | | - "iopub.status.busy": "2025-12-22T06:39:13.162285Z", |
142 | | - "iopub.status.idle": "2025-12-22T06:39:14.345651Z", |
143 | | - "shell.execute_reply": "2025-12-22T06:39:14.344809Z" |
| 140 | + "iopub.execute_input": "2025-12-22T06:53:14.508507Z", |
| 141 | + "iopub.status.busy": "2025-12-22T06:53:14.508354Z", |
| 142 | + "iopub.status.idle": "2025-12-22T06:53:15.786608Z", |
| 143 | + "shell.execute_reply": "2025-12-22T06:53:15.785666Z" |
144 | 144 | } |
145 | 145 | }, |
146 | 146 | "outputs": [], |
|
219 | 219 | "id": "fa585400", |
220 | 220 | "metadata": { |
221 | 221 | "execution": { |
222 | | - "iopub.execute_input": "2025-12-22T06:39:14.347524Z", |
223 | | - "iopub.status.busy": "2025-12-22T06:39:14.347292Z", |
224 | | - "iopub.status.idle": "2025-12-22T06:39:14.354413Z", |
225 | | - "shell.execute_reply": "2025-12-22T06:39:14.353607Z" |
| 222 | + "iopub.execute_input": "2025-12-22T06:53:15.788481Z", |
| 223 | + "iopub.status.busy": "2025-12-22T06:53:15.788221Z", |
| 224 | + "iopub.status.idle": "2025-12-22T06:53:15.796119Z", |
| 225 | + "shell.execute_reply": "2025-12-22T06:53:15.795231Z" |
226 | 226 | } |
227 | 227 | }, |
228 | 228 | "outputs": [ |
|
303 | 303 | "id": "98b465d4", |
304 | 304 | "metadata": { |
305 | 305 | "execution": { |
306 | | - "iopub.execute_input": "2025-12-22T06:39:14.356021Z", |
307 | | - "iopub.status.busy": "2025-12-22T06:39:14.355844Z", |
308 | | - "iopub.status.idle": "2025-12-22T06:39:14.368147Z", |
309 | | - "shell.execute_reply": "2025-12-22T06:39:14.367369Z" |
| 306 | + "iopub.execute_input": "2025-12-22T06:53:15.797968Z", |
| 307 | + "iopub.status.busy": "2025-12-22T06:53:15.797751Z", |
| 308 | + "iopub.status.idle": "2025-12-22T06:53:15.810927Z", |
| 309 | + "shell.execute_reply": "2025-12-22T06:53:15.809984Z" |
310 | 310 | } |
311 | 311 | }, |
312 | 312 | "outputs": [], |
|
441 | 441 | "id": "6ccb1cfa", |
442 | 442 | "metadata": { |
443 | 443 | "execution": { |
444 | | - "iopub.execute_input": "2025-12-22T06:39:14.369660Z", |
445 | | - "iopub.status.busy": "2025-12-22T06:39:14.369495Z", |
446 | | - "iopub.status.idle": "2025-12-22T06:39:14.376087Z", |
447 | | - "shell.execute_reply": "2025-12-22T06:39:14.375447Z" |
| 444 | + "iopub.execute_input": "2025-12-22T06:53:15.812596Z", |
| 445 | + "iopub.status.busy": "2025-12-22T06:53:15.812412Z", |
| 446 | + "iopub.status.idle": "2025-12-22T06:53:15.819625Z", |
| 447 | + "shell.execute_reply": "2025-12-22T06:53:15.818916Z" |
448 | 448 | } |
449 | 449 | }, |
450 | 450 | "outputs": [ |
|
536 | 536 | "id": "e16a0f66", |
537 | 537 | "metadata": { |
538 | 538 | "execution": { |
539 | | - "iopub.execute_input": "2025-12-22T06:39:14.377908Z", |
540 | | - "iopub.status.busy": "2025-12-22T06:39:14.377753Z", |
541 | | - "iopub.status.idle": "2025-12-22T06:39:14.386118Z", |
542 | | - "shell.execute_reply": "2025-12-22T06:39:14.385472Z" |
| 539 | + "iopub.execute_input": "2025-12-22T06:53:15.821459Z", |
| 540 | + "iopub.status.busy": "2025-12-22T06:53:15.821256Z", |
| 541 | + "iopub.status.idle": "2025-12-22T06:53:15.830300Z", |
| 542 | + "shell.execute_reply": "2025-12-22T06:53:15.829509Z" |
543 | 543 | } |
544 | 544 | }, |
545 | 545 | "outputs": [], |
|
633 | 633 | "id": "c034fcc9", |
634 | 634 | "metadata": { |
635 | 635 | "execution": { |
636 | | - "iopub.execute_input": "2025-12-22T06:39:14.387898Z", |
637 | | - "iopub.status.busy": "2025-12-22T06:39:14.387739Z", |
638 | | - "iopub.status.idle": "2025-12-22T06:39:14.397542Z", |
639 | | - "shell.execute_reply": "2025-12-22T06:39:14.396740Z" |
| 636 | + "iopub.execute_input": "2025-12-22T06:53:15.832058Z", |
| 637 | + "iopub.status.busy": "2025-12-22T06:53:15.831676Z", |
| 638 | + "iopub.status.idle": "2025-12-22T06:53:15.841992Z", |
| 639 | + "shell.execute_reply": "2025-12-22T06:53:15.841212Z" |
640 | 640 | }, |
641 | 641 | "lines_to_next_cell": 2 |
642 | 642 | }, |
|
732 | 732 | "id": "849042f7", |
733 | 733 | "metadata": { |
734 | 734 | "execution": { |
735 | | - "iopub.execute_input": "2025-12-22T06:39:14.399086Z", |
736 | | - "iopub.status.busy": "2025-12-22T06:39:14.398933Z", |
737 | | - "iopub.status.idle": "2025-12-22T06:39:14.404434Z", |
738 | | - "shell.execute_reply": "2025-12-22T06:39:14.403648Z" |
| 735 | + "iopub.execute_input": "2025-12-22T06:53:15.843705Z", |
| 736 | + "iopub.status.busy": "2025-12-22T06:53:15.843548Z", |
| 737 | + "iopub.status.idle": "2025-12-22T06:53:15.850364Z", |
| 738 | + "shell.execute_reply": "2025-12-22T06:53:15.849620Z" |
739 | 739 | } |
740 | 740 | }, |
741 | 741 | "outputs": [], |
|
758 | 758 | " # Create boolean masks for each split\n", |
759 | 759 | " self.split_masks = [assignments == i for i in range(num_splits)]\n", |
760 | 760 | "\n", |
| 761 | + " # Store assignments for merge_split_indices\n", |
| 762 | + " self.register_buffer(\"_assignments\", assignments)\n", |
| 763 | + "\n", |
761 | 764 | " @property\n", |
762 | 765 | " def feature_to_scope(self):\n", |
763 | 766 | " scopes = self.inputs.feature_to_scope\n", |
|
769 | 772 | " @cached\n", |
770 | 773 | " def log_likelihood(self, data, cache=None):\n", |
771 | 774 | " lls = self.inputs.log_likelihood(data, cache=cache)\n", |
772 | | - " return [lls[:, mask, ...] for mask in self.split_masks]" |
| 775 | + " return [lls[:, mask, ...] for mask in self.split_masks]\n", |
| 776 | + "\n", |
| 777 | + "\n", |
| 778 | + " def merge_split_indices(self, *split_indices: Tensor) -> Tensor:\n", |
| 779 | + " batch_size = split_indices[0].shape[0]\n", |
| 780 | + " num_features = self.inputs.out_shape.features\n", |
| 781 | + " # Create output tensor\n", |
| 782 | + " result = torch.zeros(batch_size, num_features, dtype=split_indices[0].dtype, device=split_indices[0].device)\n", |
| 783 | + " # Track position within each split\n", |
| 784 | + " split_positions = [0] * self.num_splits\n", |
| 785 | + " # Scatter indices back to original positions\n", |
| 786 | + " for feature_idx in range(num_features):\n", |
| 787 | + " split_idx = self._assignments[feature_idx].item()\n", |
| 788 | + " pos = split_positions[split_idx]\n", |
| 789 | + " result[:, feature_idx] = split_indices[split_idx][:, pos]\n", |
| 790 | + " split_positions[split_idx] += 1\n", |
| 791 | + " return result" |
773 | 792 | ] |
774 | 793 | }, |
775 | 794 | { |
|
778 | 797 | "id": "f5b66811", |
779 | 798 | "metadata": { |
780 | 799 | "execution": { |
781 | | - "iopub.execute_input": "2025-12-22T06:39:14.405857Z", |
782 | | - "iopub.status.busy": "2025-12-22T06:39:14.405676Z", |
783 | | - "iopub.status.idle": "2025-12-22T06:39:14.572360Z", |
784 | | - "shell.execute_reply": "2025-12-22T06:39:14.571565Z" |
| 800 | + "iopub.execute_input": "2025-12-22T06:53:15.852010Z", |
| 801 | + "iopub.status.busy": "2025-12-22T06:53:15.851829Z", |
| 802 | + "iopub.status.idle": "2025-12-22T06:53:15.856663Z", |
| 803 | + "shell.execute_reply": "2025-12-22T06:53:15.855999Z" |
785 | 804 | } |
786 | 805 | }, |
787 | 806 | "outputs": [ |
788 | 807 | { |
789 | | - "ename": "TypeError", |
790 | | - "evalue": "Can't instantiate abstract class RandomSplit with abstract method merge_split_indices", |
791 | | - "output_type": "error", |
792 | | - "traceback": [ |
793 | | - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", |
794 | | - "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", |
795 | | - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Test RandomSplit\u001b[39;00m\n\u001b[32m 2\u001b[39m leaf = Normal(scope=Scope(\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[32m6\u001b[39m))), out_channels=\u001b[32m2\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m split = \u001b[43mRandomSplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mleaf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_splits\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m123\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 5\u001b[39m data = torch.randn(\u001b[32m5\u001b[39m, \u001b[32m6\u001b[39m)\n\u001b[32m 6\u001b[39m lls = split.log_likelihood(data)\n", |
796 | | - "\u001b[31mTypeError\u001b[39m: Can't instantiate abstract class RandomSplit with abstract method merge_split_indices" |
| 808 | + "name": "stdout", |
| 809 | + "output_type": "stream", |
| 810 | + "text": [ |
| 811 | + "Split 0 shape: torch.Size([5, 5, 2, 1])\n", |
| 812 | + "Split 1 shape: torch.Size([5, 1, 2, 1])\n" |
797 | 813 | ] |
798 | 814 | } |
799 | 815 | ], |
|
831 | 847 | "id": "81fc02a3", |
832 | 848 | "metadata": { |
833 | 849 | "execution": { |
834 | | - "iopub.execute_input": "2025-12-22T06:39:14.574005Z", |
835 | | - "iopub.status.busy": "2025-12-22T06:39:14.573846Z", |
836 | | - "iopub.status.idle": "2025-12-22T06:39:14.580174Z", |
837 | | - "shell.execute_reply": "2025-12-22T06:39:14.579384Z" |
| 850 | + "iopub.execute_input": "2025-12-22T06:53:15.858475Z", |
| 851 | + "iopub.status.busy": "2025-12-22T06:53:15.858322Z", |
| 852 | + "iopub.status.idle": "2025-12-22T06:53:15.864184Z", |
| 853 | + "shell.execute_reply": "2025-12-22T06:53:15.863391Z" |
838 | 854 | }, |
839 | 855 | "lines_to_next_cell": 2 |
840 | 856 | }, |
|
0 commit comments