Skip to content

Commit 843a09f

Browse files
authored
Merge pull request #13 from PolymathicAI/tutorial
fixed tutorial
2 parents cb92c5c + bc45957 commit 843a09f

File tree

2 files changed

+256
-307
lines changed

2 files changed

+256
-307
lines changed

notebooks/tutorial/AstroCLIPTutorial.ipynb

Lines changed: 254 additions & 303 deletions
Large diffs are not rendered by default.

notebooks/tutorial/tutorial_helpers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,10 @@ def forward(
5252
pos_emb = self.position_embed(pos) # to shape (t, embedding_dim)
5353

5454
x = self.dropout(data_emb + pos_emb)
55-
embeddings = []
5655
for block in self.blocks:
5756
x = block(x)
58-
embeddings.append(x.detach().clone())
5957
x = self.final_layernorm(x)
58+
embedding = x.detach().clone()
6059

6160
preds = self.head(x)
6261
if y is not None:
@@ -66,7 +65,7 @@ def forward(
6665
else:
6766
loss = None
6867

69-
return {"preds": preds, "loss": loss, "embeddings": embeddings}
68+
return {"preds": preds, "loss": loss, "embedding": embedding}
7069

7170
def slice(x, section_length=10, overlap=5):
7271

@@ -79,7 +78,6 @@ def slice(x, section_length=10, overlap=5):
7978

8079
return torch.cat(sections, 1)
8180

82-
8381
def fnc(x):
8482
std, mean = x.std(1, keepdim=True).clip_(0.2), x.mean(1, keepdim=True)
8583
x = (x - mean) / std

0 commit comments

Comments
 (0)