Skip to content

Commit 615667f

Browse files
committed
Adding simplified pages for __init__
1 parent 72c86d2 commit 615667f

File tree

6 files changed

+58
-29
lines changed

6 files changed

+58
-29
lines changed

docs/gen_ref_pages.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import mkdocs_gen_files
66
from griffe.collections import ModulesCollection
7+
from griffe.dataclasses import Alias, Module
78
from griffe.loader import GriffeLoader
89

910
TOP_LEVEL_NAME = "pytorch_adapt"
@@ -14,34 +15,63 @@ def remove_pytorch_adapt(x):
1415
return [z for z in list(x.parts) if z != TOP_LEVEL_NAME]
1516

1617

17-
def main():
18-
collection = ModulesCollection()
19-
loader = GriffeLoader(modules_collection=collection)
20-
loader.load_module(Path("src", TOP_LEVEL_NAME))
21-
nav = mkdocs_gen_files.Nav()
18+
def get_init_entries(module_instance, init_entries, prefix=""):
19+
for k, v in module_instance.members.items():
20+
if isinstance(v, Module) and v.is_init_module:
21+
init_key = f"{prefix}.{k}" if prefix else k
22+
init_entries[init_key] = sorted(
23+
[
24+
(name, member.target_path)
25+
for name, member in v.members.items()
26+
if isinstance(member, Alias)
27+
]
28+
)
29+
get_init_entries(v, init_entries, prefix=init_key)
30+
return init_entries
31+
32+
33+
def set_init_pages(module_instance, nav):
34+
init_entries = get_init_entries(module_instance, {}, prefix="")
35+
for k, v in init_entries.items():
36+
k_split = k.split(".")
37+
doc_path = Path(*k_split, "index").with_suffix(".md")
38+
full_doc_path = Path(FOLDER, doc_path)
39+
nav[k_split] = doc_path
40+
to_write = "\n".join([f"- [{x[0]}][{x[1]}]" for x in v])
41+
with mkdocs_gen_files.open(full_doc_path, "w") as fd:
42+
fd.write(to_write)
2243

44+
45+
def set_non_init_pages(collection, nav):
2346
for path in sorted(Path("src").rglob("*.py")):
2447
module_path = path.relative_to("src").with_suffix("")
25-
parts = list(module_path.parts)
48+
parts = tuple(module_path.parts)
2649
if (parts[-1] in ["__init__", "__main__"]) or (
2750
not collection[module_path.parts].has_docstrings
2851
):
2952
continue
3053

3154
doc_path = path.relative_to("src").with_suffix(".md")
32-
full_doc_path = Path(FOLDER, doc_path)
3355
doc_path = Path(*remove_pytorch_adapt(doc_path))
34-
full_doc_path = Path(*remove_pytorch_adapt(full_doc_path))
56+
full_doc_path = Path(FOLDER, doc_path)
3557

3658
for_nav = remove_pytorch_adapt(module_path)
3759
nav[for_nav] = doc_path
3860

3961
with mkdocs_gen_files.open(full_doc_path, "w") as fd:
4062
ident = ".".join(parts)
41-
print("::: " + ident, file=fd)
63+
fd.write(f"::: {ident}")
4264

4365
mkdocs_gen_files.set_edit_path(full_doc_path, path)
4466

67+
68+
def main():
69+
collection = ModulesCollection()
70+
loader = GriffeLoader(modules_collection=collection)
71+
module_instance = loader.load_module(Path("src", TOP_LEVEL_NAME))
72+
nav = mkdocs_gen_files.Nav()
73+
set_init_pages(module_instance, nav)
74+
set_non_init_pages(collection, nav)
4575
with mkdocs_gen_files.open(f"{FOLDER}/SUMMARY.md", "w") as nav_file:
4676
nav_file.writelines(nav.build_literate_nav())
4777

mkdocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ plugins:
3232
python:
3333
setup_commands:
3434
- import sys
35-
- sys.path.append("src")
35+
- sys.path.insert(0, "src")
3636
rendering:
3737
show_root_toc_entry: false
3838
- section-index

src/pytorch_adapt/hooks/gvb.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,13 @@ def call(self, inputs, losses):
9898
outputs = self.hook(inputs, losses)[0]
9999
strs = c_f.filter(self.hook.out_keys, f"_[a-z]bridge$", ["^src", "^target"])
100100
[src_bridge, target_bridge] = c_f.extract([outputs, inputs], strs)
101-
return outputs, {
102-
f"src_bridge_loss": self.loss_fn(src_bridge),
103-
f"target_bridge_loss": self.loss_fn(target_bridge),
104-
}
101+
return (
102+
outputs,
103+
{
104+
f"src_bridge_loss": self.loss_fn(src_bridge),
105+
f"target_bridge_loss": self.loss_fn(target_bridge),
106+
},
107+
)
105108

106109
def _loss_keys(self):
107110
return [f"src_bridge_loss", f"target_bridge_loss"]

src/pytorch_adapt/hooks/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ def __init__(self, loss_names: List[str], out_names: List[str], **kwargs):
4040
def call(self, inputs, losses):
4141
""""""
4242
out_keys = set(self.out_names) - inputs.keys()
43-
return {k: None for k in out_keys}, {
44-
k: c_f.zero_loss() for k in self.loss_names
45-
}
43+
return (
44+
{k: None for k in out_keys},
45+
{k: c_f.zero_loss() for k in self.loss_names},
46+
)
4647

4748
def _loss_keys(self):
4849
""""""

src/pytorch_adapt/hooks/vada.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ def call(self, inputs, losses):
2828
)
2929
src_vat_loss = self.loss_fn(src_imgs, src_logits, combined_model)
3030
target_vat_loss = self.loss_fn(target_imgs, target_logits, combined_model)
31-
return outputs, {
32-
"src_vat_loss": src_vat_loss,
33-
"target_vat_loss": target_vat_loss,
34-
}
31+
return (
32+
outputs,
33+
{
34+
"src_vat_loss": src_vat_loss,
35+
"target_vat_loss": target_vat_loss,
36+
},
37+
)
3538

3639
def _loss_keys(self):
3740
return ["src_vat_loss", "target_vat_loss"]

src/pytorch_adapt/validators/class_cluster_validator.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,6 @@ def get_clustering_performance(
103103
centroid_init=None,
104104
feat_normalizer=None,
105105
):
106-
"""
107-
:param feats: N x out numpy vector
108-
:param labels: N numpy vector
109-
:param num_classes: int
110-
:param src_feats
111-
:param pca_size
112-
:return: silhouette and calinski harabasz scores
113-
"""
114106
num_target_feats = feats.shape[0]
115107

116108
if src_feats is not None:

0 commit comments

Comments
 (0)