Skip to content

Commit faa2ee5

Browse files
committed
first release
1 parent 60c1f1c commit faa2ee5

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# SpidR: Learning Fast and Stable Linguistic Units for Spoken Language Models Without Supervision
22

3-
[📜 [Paper](https://arxiv.org)] [📖 [BibTeX](https://github.com/facebookresearch/spidr/tree/main?tab=readme-ov-file#citation)]
4-
5-
This repository contains the checkpoints and training code for the self-supervised speech models from https://arxiv.org.
3+
This repository contains the checkpoints and training code for the self-supervised speech models
4+
in the SpidR paper (coming soon!).
65

76
## Overview
87

hubconf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pathlib import Path
55

6-
from torch.hub import _add_to_sys_path
6+
from torch.hub import _add_to_sys_path # noqa: PLC2701
77

88
dependencies = ["torch", "numpy"]
99

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "spidr"
7-
version = "1.0.0"
7+
version = "0.1.0"
88
description = "Learning Fast and Stable Linguistic Units for Spoken Language Models Without Supervision"
99
readme = "README.md"
1010
license = "CC-BY-NC-4.0"

src/spidr/models/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,19 @@ def state_dict_from_dinosr_fairseq_checkpoint(checkpoint: dict[str, Any]) -> Ord
183183
def spidr_base(*, pretrained: bool = True, check_hash: bool = False, progress: bool = True) -> SpidR:
184184
model = SpidR(SpidRConfig())
185185
if pretrained:
186-
url = ""
187-
state_dict = load_state_dict_from_url(url, check_hash=check_hash, progress=progress)
188-
model.load_state_dict(state_dict)
186+
url = "https://dl.fbaipublicfiles.com/shared/devai/assets/models/spidr_base.pt"
187+
checkpoint = load_state_dict_from_url(url, check_hash=check_hash, progress=progress, map_location="cpu")
188+
model.load_state_dict(checkpoint["model"])
189189
model.eval()
190190
return model
191191

192192

193193
def dinosr_base_reproduced(*, pretrained: bool = True, check_hash: bool = False, progress: bool = True) -> DinoSR:
194194
model = DinoSR(DinoSRConfig())
195195
if pretrained:
196-
url = ""
197-
state_dict = load_state_dict_from_url(url, check_hash=check_hash, progress=progress)
198-
model.load_state_dict(state_dict)
196+
url = "https://dl.fbaipublicfiles.com/shared/devai/assets/models/dinosr_base_reproduced.pt"
197+
checkpoint = load_state_dict_from_url(url, check_hash=check_hash, progress=progress, map_location="cpu")
198+
model.load_state_dict(checkpoint["model"])
199199
model.eval()
200200
return model
201201

0 commit comments

Comments
 (0)