Commit effbb00

Force contiguous searchsorted (#396)

* force contiguous tensors in torch searchsorted
* improve warning message for torch backend
* run linter
* undo using workers in workflow (subject to a different issue)

1 parent bb973d2 commit effbb00

File tree

2 files changed: +7 -1 lines changed

bayesflow/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -40,8 +40,12 @@ def setup():
         torch.autograd.set_grad_enabled(False)

         logging.warning(
+            "\n"
             "When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n"
+            "\n"
             "with torch.enable_grad():\n"
+            "    ...\n"
+            "\n"
             "in contexts where you need gradients (e.g. custom training loops)."
         )
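The enlarged warning spells out the recommended pattern. As a minimal, hypothetical sketch (the tensor and loss below are illustrative, not part of the commit), re-enabling autograd locally looks like this:

```python
import torch

# BayesFlow's torch backend disables autograd globally at setup to avoid
# excessive memory usage; torch.enable_grad() turns gradient tracking back
# on only inside the with-block.
torch.autograd.set_grad_enabled(False)

with torch.enable_grad():
    x = torch.ones(3, requires_grad=True)  # illustrative tensor
    loss = (x ** 2).sum()
    loss.backward()

print(x.grad)  # tensor([2., 2., 2.])
```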

bayesflow/utils/tensor_utils.py
Lines changed: 3 additions & 1 deletion

@@ -202,7 +202,9 @@ def searchsorted(sorted_sequence: Tensor, values: Tensor, side: str = "left") ->

            out_int32 = len(sorted_sequence) <= np.iinfo(np.int32).max

-           indices = torch.searchsorted(sorted_sequence, values, side=side, out_int32=out_int32)
+           indices = torch.searchsorted(
+               sorted_sequence.contiguous(), values.contiguous(), side=side, out_int32=out_int32
+           )

            return indices
        case _:
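The fix addresses torch.searchsorted's requirement that its input tensors be contiguous in the innermost dimension, which views such as transposes can violate. A small illustration (the tensors here are made up for demonstration):

```python
import torch

sorted_seq = torch.arange(10.0)  # contiguous, sorted boundary values

# A transposed view shares storage with the original tensor and is
# typically non-contiguous, which torch.searchsorted may reject.
values = torch.tensor([[0.5, 7.1], [3.2, 9.0]]).t()
print(values.is_contiguous())  # False

# .contiguous() copies only when needed and returns the tensor itself if it
# is already contiguous, so the call is cheap in the common case.
indices = torch.searchsorted(sorted_seq.contiguous(), values.contiguous())
print(indices)  # tensor([[1, 4], [8, 9]])
```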
