Skip to content
4 changes: 4 additions & 0 deletions notebooks/dummy_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import torch

class Identity(torch.nn.Module):
def __init__(self): super().__init__()
def forward(self, x): return x

class MeanAlongDim(torch.nn.Module):
def __init__(self, ax):
super(MeanAlongDim, self).__init__()
Expand Down
37 changes: 23 additions & 14 deletions notebooks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def _resample_coordinate(
offset = 0 if mode == "edges" else old_step / 2
new_step = old_step / factor
coord = coord - offset
return np.arange(coord.min().item(), coord.max().item()+old_step, step=new_step) + offset
new_coord_end = coord.max().item() + old_step
return np.arange(coord.min().item(), new_coord_end, step=new_step) + offset


def _get_output_array_coordinates(
Expand All @@ -90,7 +91,7 @@ def _get_output_array_coordinates(
output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode)
elif dim in src_da.coords:
# Source array has coordinate but it isn't changing size
output_coords[dim] = src_da[dim].copy()
output_coords[dim] = src_da[dim].copy(deep=True).data
else:
# Source array doesn't have a coordinate on this dim or
# this is a new dim, ignore
Expand Down Expand Up @@ -168,22 +169,28 @@ def predict_on_array(

Overlaps are allowed, in which case the average of all output values is returned.
'''
# TODO input checking
# *_dim args cannot have common axes
s_new = set(new_dim)
s_core = set(core_dim)
s_resample = set(resample_dim)

if s_new & s_core or s_new & s_resample or s_core & s_resample:
raise ValueError("new_dim, core_dim, and resample_dim must be disjoint sets.")

bgen = dataset.X_generator

# Get resample factors
resample_factor = _get_resample_factor(
dataset.X_generator,
bgen,
output_tensor_dim,
resample_dim
)

# Set up output array
output_size = _get_output_array_size(
dataset.X_generator,
output_tensor_dim,
new_dim,
core_dim,
bgen,
output_tensor_dim,
new_dim,
core_dim,
resample_dim
)

Expand All @@ -198,12 +205,14 @@ def predict_on_array(

# Iterate over each batch
for i, batch in enumerate(loader):
out_batch = model(batch).detach().numpy()
input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
out_batch = model(input_tensor).detach().numpy()

# Iterate over each example in the batch
# Iterate over each sample in the batch
for ib in range(out_batch.shape[0]):
# Get the slice object associated with this example
old_indexer = dataset.X_generator._batch_selectors.selectors[(i*batch_size)+ib][0]
# Get the slice object associated with this sample
global_index = (i * batch_size) + ib
old_indexer = bgen._batch_selectors.selectors[global_index][0]
# Only index into axes that are resampled, rescaling the bounds
# Perhaps use xbatcher _gen_slices here?
new_indexer = {}
Expand All @@ -213,7 +222,7 @@ def predict_on_array(
int(old_indexer[key].start * resample_factor[key]),
int(old_indexer[key].stop * resample_factor[key])
)

output_da.loc[new_indexer] += out_batch[ib, ...]
output_n.loc[new_indexer] += 1

Expand Down
Loading