From 86847f30b9fe5eda96dae61abf3a67f92f7c4cd8 Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 8 Apr 2025 15:20:34 -0400
Subject: [PATCH 1/4] force contiguous tensors in torch searchsorted

---
 bayesflow/utils/tensor_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/bayesflow/utils/tensor_utils.py b/bayesflow/utils/tensor_utils.py
index 9f161aaca..cd6f6d4ca 100644
--- a/bayesflow/utils/tensor_utils.py
+++ b/bayesflow/utils/tensor_utils.py
@@ -202,7 +202,9 @@ def searchsorted(sorted_sequence: Tensor, values: Tensor, side: str = "left") ->
 
             out_int32 = len(sorted_sequence) <= np.iinfo(np.int32).max
 
-            indices = torch.searchsorted(sorted_sequence, values, side=side, out_int32=out_int32)
+            indices = torch.searchsorted(
+                sorted_sequence.contiguous(), values.contiguous(), side=side, out_int32=out_int32
+            )
 
             return indices
         case _:

From 57a7392d899789b94e084a52cd63c8e4184531c3 Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 8 Apr 2025 15:20:44 -0400
Subject: [PATCH 2/4] improve warning message for torch backend

---
 bayesflow/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py
index 6f098ea86..e0ce82073 100644
--- a/bayesflow/__init__.py
+++ b/bayesflow/__init__.py
@@ -40,8 +40,12 @@ def setup():
         torch.autograd.set_grad_enabled(False)
 
         logging.warning(
+            "\n"
             "When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n"
+            "\n"
             "with torch.enable_grad():\n"
+            "    ...\n"
+            "\n"
             "in contexts where you need gradients (e.g. custom training loops)."
         )
 

From d862baf88368d7004b0d2689dce391063b0ede7e Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 8 Apr 2025 15:20:49 -0400
Subject: [PATCH 3/4] run linter

---
 bayesflow/workflows/basic_workflow.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py
index e2b0aaa7c..723fa629b 100644
--- a/bayesflow/workflows/basic_workflow.py
+++ b/bayesflow/workflows/basic_workflow.py
@@ -738,8 +738,14 @@ def fit_online(
             metric evolution over epochs.
         """
 
+        import multiprocessing as mp
+
         dataset = OnlineDataset(
-            simulator=self.simulator, batch_size=batch_size, num_batches=num_batches_per_epoch, adapter=self.adapter
+            simulator=self.simulator,
+            batch_size=batch_size,
+            num_batches=num_batches_per_epoch,
+            adapter=self.adapter,
+            workers=mp.cpu_count(),
         )
 
         return self._fit(

From 657d28f71d04c86bed764e360460bb94824ee26a Mon Sep 17 00:00:00 2001
From: LarsKue
Date: Tue, 8 Apr 2025 15:24:13 -0400
Subject: [PATCH 4/4] undo using workers in workflow (subject to a different
 issue)

---
 bayesflow/workflows/basic_workflow.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py
index 723fa629b..e2b0aaa7c 100644
--- a/bayesflow/workflows/basic_workflow.py
+++ b/bayesflow/workflows/basic_workflow.py
@@ -738,14 +738,8 @@ def fit_online(
             metric evolution over epochs.
         """
 
-        import multiprocessing as mp
-
         dataset = OnlineDataset(
-            simulator=self.simulator,
-            batch_size=batch_size,
-            num_batches=num_batches_per_epoch,
-            adapter=self.adapter,
-            workers=mp.cpu_count(),
+            simulator=self.simulator, batch_size=batch_size, num_batches=num_batches_per_epoch, adapter=self.adapter
         )
 
         return self._fit(
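
---

Note on PATCH 1/4: torch.searchsorted can reject non-contiguous inputs
(views produced by transpose or slicing) with a RuntimeError, which is
what forcing .contiguous() guards against. A minimal sketch of the
failure mode and the fix; the exact error wording varies by torch
version:

    import torch

    # Row-wise sorted boundaries, and a transposed values view;
    # .T returns a non-contiguous view, not a copy.
    sorted_sequence = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    values = torch.rand(4, 3).T

    # May raise a RuntimeError about non-contiguous tensors on some
    # torch versions:
    # indices = torch.searchsorted(sorted_sequence, values)

    # .contiguous() is a no-op on already-contiguous tensors, so the
    # fix only copies when a copy is actually required.
    indices = torch.searchsorted(sorted_sequence, values.contiguous())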
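
Note on PATCH 2/4: the reformatted warning documents that bayesflow
globally disables autograd under the torch backend. A minimal sketch of
re-enabling gradients locally, as the warning message suggests:

    import torch

    # What bayesflow's setup() does under the torch backend:
    torch.autograd.set_grad_enabled(False)

    x = torch.randn(8, requires_grad=True)

    with torch.enable_grad():
        loss = (x ** 2).sum()
        loss.backward()  # the graph is recorded inside the context

    print(x.grad)  # populated despite grad mode being globally disabled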
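
Note on PATCH 3/4 and PATCH 4/4: alongside the linter reformat, patch 3
made fit_online() spread batch simulation across all CPU cores via
OnlineDataset's `workers` argument, and patch 4 reverts that pending a
separate issue. Until then, a caller can opt in explicitly; a sketch
assuming OnlineDataset is importable from bayesflow.datasets and
accepts `workers` as in the reverted patch, with `simulator` and
`adapter` as placeholders for the objects a workflow would supply:

    import multiprocessing as mp

    from bayesflow.datasets import OnlineDataset

    # simulator / adapter: stand-ins for workflow.simulator and
    # workflow.adapter; batch sizes here are hypothetical.
    dataset = OnlineDataset(
        simulator=simulator,
        batch_size=64,
        num_batches=128,
        adapter=adapter,
        workers=mp.cpu_count(),  # the setting patch 3 added and patch 4 removed
    )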