From e88425bc57c22b89b0b12e7448d579fc3c9956b2 Mon Sep 17 00:00:00 2001
From: Kursat Kaya
Date: Mon, 6 Oct 2025 16:47:03 +0200
Subject: [PATCH 1/2] fix PriorDumpDataLoader to correctly handle num_datapoints and max_seq_in_batch

---
 nanotabpfn/priors.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/nanotabpfn/priors.py b/nanotabpfn/priors.py
index 73ca8f3..a31f822 100644
--- a/nanotabpfn/priors.py
+++ b/nanotabpfn/priors.py
@@ -28,9 +28,10 @@ def __init__(self, filename, num_steps, batch_size, device, starting_index=0):
         self.num_steps = num_steps
         self.batch_size = batch_size
         with h5py.File(self.filename, "r") as f:
-            self.num_datapoints_max = f['X'].shape[0]
+            self.num_datapoints_max = f['X'].shape[1]
             self.max_num_classes = f['max_num_classes'][0]
             self.problem_type = f['problem_type'][()].decode('utf-8')
+            self.has_num_datapoints = "num_datapoints" in f
         self.device = device
         self.pointer = starting_index

@@ -41,8 +42,12 @@ def __iter__(self):
         end = self.pointer + self.batch_size
         num_features=self.data['num_features'][self.pointer:end].max()

-        x = torch.from_numpy(self.data['X'][self.pointer:end,:,:num_features])
-        y = torch.from_numpy(self.data['y'][self.pointer:end])
+        if self.has_num_datapoints:
+            max_seq_in_batch = int(f['num_datapoints'][self.pointer:end].max())
+        else:
+            max_seq_in_batch = int(self.num_datapoints_max)
+        x = torch.from_numpy(self.data['X'][self.pointer:end, :max_seq_in_batch, :num_features])
+        y = torch.from_numpy(self.data['y'][self.pointer:end, :max_seq_in_batch])
         single_eval_pos = self.data['single_eval_pos'][self.pointer:end]

         self.pointer += self.batch_size

From d077f0d28decfbc106fe6387511586d1da9bc8e7 Mon Sep 17 00:00:00 2001
From: Kürşat Kaya <30809805+kursatfelsen@users.noreply.github.com>
Date: Mon, 6 Oct 2025 23:31:47 +0200
Subject: [PATCH 2/2] Update nanotabpfn/priors.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 nanotabpfn/priors.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nanotabpfn/priors.py b/nanotabpfn/priors.py
index a31f822..13d064c 100644
--- a/nanotabpfn/priors.py
+++ b/nanotabpfn/priors.py
@@ -43,7 +43,7 @@ def __iter__(self):
         num_features=self.data['num_features'][self.pointer:end].max()

         if self.has_num_datapoints:
-            max_seq_in_batch = int(f['num_datapoints'][self.pointer:end].max())
+            max_seq_in_batch = int(self.data['num_datapoints'][self.pointer:end].max())
         else:
             max_seq_in_batch = int(self.num_datapoints_max)
         x = torch.from_numpy(self.data['X'][self.pointer:end, :max_seq_in_batch, :num_features])
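
For reviewers, a minimal sketch (not part of the patch) of the batch-trimming behavior the two commits end up with: when the dump contains a `num_datapoints` dataset, each batch is sliced to the longest dataset in that batch; otherwise the full `num_datapoints_max` from `X`'s second dimension is used. The toy HDF5 file, its sizes, and the filename below are hypothetical; only the field names ('X', 'y', 'num_datapoints', 'num_features') mirror the dump layout assumed by the diff.

```python
# Hedged sketch of the post-fix slicing logic; the toy dump below is hypothetical.
import h5py
import numpy as np
import torch

# Build a tiny prior dump: 4 datasets, up to 10 datapoints, up to 3 features.
with h5py.File("toy_prior_dump.h5", "w") as f:
    f["X"] = np.random.randn(4, 10, 3).astype(np.float32)    # (datasets, datapoints, features)
    f["y"] = np.random.randint(0, 2, size=(4, 10)).astype(np.float32)
    f["num_datapoints"] = np.array([10, 6, 8, 7])             # actual length of each dataset
    f["num_features"] = np.array([3, 2, 3, 1])                # actual feature count of each dataset

with h5py.File("toy_prior_dump.h5", "r") as data:
    start, end = 2, 4                                         # one batch of two datasets
    num_features = int(data["num_features"][start:end].max())
    if "num_datapoints" in data:                              # mirrors self.has_num_datapoints
        max_seq_in_batch = int(data["num_datapoints"][start:end].max())
    else:
        max_seq_in_batch = data["X"].shape[1]                 # mirrors self.num_datapoints_max
    x = torch.from_numpy(data["X"][start:end, :max_seq_in_batch, :num_features])
    y = torch.from_numpy(data["y"][start:end, :max_seq_in_batch])
    print(x.shape, y.shape)  # torch.Size([2, 8, 3]) torch.Size([2, 8])
```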