Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions notebooks/dummy_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import torch

class Identity(torch.nn.Module):
def __init__(self): super().__init__()
def forward(self, x): return x

class MeanAlongDim(torch.nn.Module):
def __init__(self, ax):
super(MeanAlongDim, self).__init__()
Expand Down
37 changes: 23 additions & 14 deletions notebooks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def _resample_coordinate(
offset = 0 if mode == "edges" else old_step / 2
new_step = old_step / factor
coord = coord - offset
return np.arange(coord.min().item(), coord.max().item()+old_step, step=new_step) + offset
new_coord_end = coord.max().item() + old_step
return np.arange(coord.min().item(), new_coord_end, step=new_step) + offset


def _get_output_array_coordinates(
Expand All @@ -90,7 +91,7 @@ def _get_output_array_coordinates(
output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode)
elif dim in src_da.coords:
# Source array has coordinate but it isn't changing size
output_coords[dim] = src_da[dim].copy()
output_coords[dim] = src_da[dim].copy(deep=True).data
else:
# Source array doesn't have a coordinate on this dim or
# this is a new dim, ignore
Expand Down Expand Up @@ -168,22 +169,28 @@ def predict_on_array(

Overlaps are allowed, in which case the average of all output values is returned.
'''
# TODO input checking
# *_dim args cannot have common axes
s_new = set(new_dim)
s_core = set(core_dim)
s_resample = set(resample_dim)

if s_new & s_core or s_new & s_resample or s_core & s_resample:
raise ValueError("new_dim, core_dim, and resample_dim must be disjoint sets.")

bgen = dataset.X_generator

# Get resample factors
resample_factor = _get_resample_factor(
dataset.X_generator,
bgen,
output_tensor_dim,
resample_dim
)

# Set up output array
output_size = _get_output_array_size(
dataset.X_generator,
output_tensor_dim,
new_dim,
core_dim,
bgen,
output_tensor_dim,
new_dim,
core_dim,
resample_dim
)

Expand All @@ -198,12 +205,14 @@ def predict_on_array(

# Iterate over each batch
for i, batch in enumerate(loader):
out_batch = model(batch).detach().numpy()
input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
out_batch = model(input_tensor).detach().numpy()

# Iterate over each example in the batch
# Iterate over each sample in the batch
for ib in range(out_batch.shape[0]):
# Get the slice object associated with this example
old_indexer = dataset.X_generator._batch_selectors.selectors[(i*batch_size)+ib][0]
# Get the slice object associated with this sample
global_index = (i * batch_size) + ib
old_indexer = bgen._batch_selectors.selectors[global_index][0]
# Only index into axes that are resampled, rescaling the bounds
# Perhaps use xbatcher _gen_slices here?
new_indexer = {}
Expand All @@ -213,7 +222,7 @@ def predict_on_array(
int(old_indexer[key].start * resample_factor[key]),
int(old_indexer[key].stop * resample_factor[key])
)

output_da.loc[new_indexer] += out_batch[ib, ...]
output_n.loc[new_indexer] += 1

Expand Down
Loading