diff --git a/nanotabpfn/priors.py b/nanotabpfn/priors.py index 73ca8f3..13d064c 100644 --- a/nanotabpfn/priors.py +++ b/nanotabpfn/priors.py @@ -28,9 +28,10 @@ def __init__(self, filename, num_steps, batch_size, device, starting_index=0): self.num_steps = num_steps self.batch_size = batch_size with h5py.File(self.filename, "r") as f: - self.num_datapoints_max = f['X'].shape[0] + self.num_datapoints_max = f['X'].shape[1] self.max_num_classes = f['max_num_classes'][0] self.problem_type = f['problem_type'][()].decode('utf-8') + self.has_num_datapoints = "num_datapoints" in f self.device = device self.pointer = starting_index @@ -41,8 +42,12 @@ def __iter__(self): end = self.pointer + self.batch_size num_features=self.data['num_features'][self.pointer:end].max() - x = torch.from_numpy(self.data['X'][self.pointer:end,:,:num_features]) - y = torch.from_numpy(self.data['y'][self.pointer:end]) + if self.has_num_datapoints: + max_seq_in_batch = int(self.data['num_datapoints'][self.pointer:end].max()) + else: + max_seq_in_batch = int(self.num_datapoints_max) + x = torch.from_numpy(self.data['X'][self.pointer:end, :max_seq_in_batch, :num_features]) + y = torch.from_numpy(self.data['y'][self.pointer:end, :max_seq_in_batch]) single_eval_pos = self.data['single_eval_pos'][self.pointer:end] self.pointer += self.batch_size