Skip to content

Commit 69f8907

Browse files
committed
wip
1 parent 1895e8a commit 69f8907

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

hyrax_config.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ phot = "PSF"
4444
mode = "forced"
4545

4646
[data_set.LSDBDataGenerator.dask]
47-
cluster_type = "none"
47+
cluster_type = "LocalCluster" # none or LocalCLuster
4848

49+
# kwargs to pass to `dask.distributed.CLient(**kwargs)`
4950
[data_set.LSDBDataGenerator.dask.LocalCluster]
5051
n_workers = 8
5152
memory_limit = "8GB"

src/uncle_val/learning/models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ class LinearModel(torch.nn.Module):
133133
err_scaler_lg_upper = 5.0
134134
err_scaler_lg_interval = err_scaler_lg_upper - err_scaler_lg_lower
135135

136-
def __init__(self, config, shape) -> None:
136+
def __init__(self, config, data_sample) -> None:
137137
super().__init__()
138-
self.layers = torch.nn.Linear(shape, 1)
138+
self.config = config
139+
input_d = len(data_sample["data"])
140+
self.layers = torch.nn.Linear(input_d, 1)
139141
self.loss = import_by_name(config["model"]["loss_fn"])
140142
self.n_src = config["data_set"]["LSDBDataGenerator"]["n_src"]
141143
if not isinstance(self.n_src, int):

0 commit comments

Comments
 (0)