Skip to content

Commit fec6c38

Browse files
committed
Add InMask and OutMask modules for easier pipeline configuration
1 parent 2a0ce29 commit fec6c38

File tree

13 files changed

+315
-3
lines changed

13 files changed

+315
-3
lines changed

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
::: synalinks.src.modules.masking.in_mask
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
::: synalinks.src.modules.masking.out_mask
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Masking Modules
2+
3+
- [InMask module](InMask module.md)
4+
- [OutMask module](OutMask module.md)

docs/Synalinks API/Modules API/index.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ if __name__ == "__main__":
5757

5858
---
5959

60+
### Masking Modules
61+
62+
- [InMask module](Masking Modules/InMask module.md)
63+
- [OutMask module](Masking Modules/OutMask module.md)
64+
65+
---
66+
6067
### Merging Modules
6168

6269
- [Concat module](Merging Modules/Concat module.md)

mkdocs.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ nav:
111111
- Synalinks API/Modules API/Retrievers Modules/index.md
112112
- Synalinks API/Modules API/Retrievers Modules/EntityRetriever module.md
113113
- Synalinks API/Modules API/Retrievers Modules/TripletRetriever module.md
114+
- Masking Modules:
115+
- Synalinks API/Modules API/Masking Modules/index.md
116+
- Synalinks API/Modules API/Masking Modules/InMask module.md
117+
- Synalinks API/Modules API/Masking Modules/OutMask module.md
114118
- Merging Modules:
115119
- Synalinks API/Modules API/Merging Modules/index.md
116120
- Synalinks API/Modules API/Merging Modules/Concat module.md
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from synalinks.src.modules.masking.in_mask import InMask
2+
from synalinks.src.modules.masking.out_mask import OutMask
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)
2+
3+
from synalinks.src import tree
4+
from synalinks.src.api_export import synalinks_export
5+
from synalinks.src.modules.module import Module
6+
7+
8+
@synalinks_export(
9+
[
10+
"synalinks.InMask",
11+
"synalinks.modules.InMask",
12+
]
13+
)
14+
class InMask(Module):
15+
"""A module to keep specific fields of the given data models
16+
17+
Example:
18+
19+
```python
20+
import synalinks
21+
import asyncio
22+
23+
language_model = synalinks.LanguageModel(
24+
model="ollama/mistral",
25+
)
26+
27+
class Document(synalinks.DataModel):
28+
title: str = synalinks.Field(
29+
description="The title of the document",
30+
)
31+
text: str = synalinks.Field(
32+
description="The content of the document",
33+
)
34+
35+
class Summary(synalinks.DataModel):
36+
summary: str = synalinks.Field(
37+
description="the concise summary of the document",
38+
)
39+
40+
async def main():
41+
inputs = Input(data_model=Document)
42+
summary = synalinks.ChainOfThought(
43+
data_model=Summary,
44+
language_model=language_model,
45+
)(inputs)
46+
masked_summary = synalinks.InMask(
47+
# remove the thinking field from the chain of thought
48+
# by keeping only the summary
49+
mask=["summary"],
50+
)(summary)
51+
52+
program = Program(
53+
inputs=inputs,
54+
outputs=masked_summary,
55+
name="summary_generator",
56+
description="Generate a summary from a document",
57+
)
58+
59+
```
60+
61+
Args:
62+
mask (list): The list of keys to keep.
63+
name (str): Optional. The name of the module.
64+
description (str): Optional. The description of the module.
65+
trainable (bool): Whether the module's variables should be trainable.
66+
"""
67+
68+
def __init__(
69+
self,
70+
mask=None,
71+
name=None,
72+
description=None,
73+
trainable=False,
74+
):
75+
if not mask or not isinstance(mask, list):
76+
raise ValueError("`mask` parameter should be a list of fields to keep")
77+
super().__init__(
78+
name=name,
79+
description=description,
80+
)
81+
self.mask = mask
82+
83+
async def call(self, inputs):
84+
outputs = tree.map_structure(
85+
lambda x: x.in_mask(mask=self.mask),
86+
inputs,
87+
)
88+
return outputs
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)
2+
3+
from synalinks.src import testing
4+
from synalinks.src.backend import DataModel
5+
from synalinks.src.modules import Input
6+
from synalinks.src.programs import Program
7+
from synalinks.src.modules.masking.in_mask import InMask
8+
9+
10+
class Document(DataModel):
11+
title: str
12+
text: str
13+
14+
15+
class InMaskTest(testing.TestCase):
16+
17+
async def test_in_mask_single_data_model(self):
18+
19+
inputs = Input(data_model=Document)
20+
21+
outputs = await InMask(
22+
mask=["text"],
23+
)(inputs)
24+
25+
program = Program(
26+
inputs=inputs,
27+
outputs=outputs,
28+
name="masking_program",
29+
description="A program to keep fields",
30+
)
31+
32+
doc = Document(title="Test document", text="Hello world")
33+
34+
result = await program(doc)
35+
36+
self.assertTrue(len(result.keys()) == 1)
37+
38+
async def test_in_mask_multiple_data_models(self):
39+
40+
inputs = [Input(data_model=Document), Input(data_model=Document)]
41+
42+
outputs = await InMask(
43+
mask=["text"],
44+
)(inputs)
45+
46+
program = Program(
47+
inputs=inputs,
48+
outputs=outputs,
49+
name="masking_program",
50+
description="A program to keep fields",
51+
)
52+
53+
doc1 = Document(title="Test document 1", text="Hello world")
54+
doc2 = Document(title="Test document 2", text="Hello world")
55+
56+
results = await program([doc1, doc2])
57+
58+
self.assertTrue(len(results[0].keys()) == 1)
59+
self.assertTrue(len(results[1].keys()) == 1)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)
2+
3+
from synalinks.src import tree
4+
from synalinks.src.api_export import synalinks_export
5+
from synalinks.src.modules.module import Module
6+
7+
8+
@synalinks_export(
9+
[
10+
"synalinks.OutMask",
11+
"synalinks.modules.OutMask",
12+
]
13+
)
14+
class OutMask(Module):
15+
"""A module to remove specific fields of the given data models
16+
17+
Example:
18+
19+
```python
20+
import synalinks
21+
import asyncio
22+
23+
language_model = synalinks.LanguageModel(
24+
model="ollama/mistral",
25+
)
26+
27+
class Document(synalinks.DataModel):
28+
title: str = synalinks.Field(
29+
description="The title of the document",
30+
)
31+
text: str = synalinks.Field(
32+
description="The content of the document",
33+
)
34+
35+
class Summary(synalinks.DataModel):
36+
summary: str = synalinks.Field(
37+
description="the concise summary of the document",
38+
)
39+
40+
async def main():
41+
inputs = Input(data_model=Document)
42+
summary = synalinks.ChainOfThought(
43+
data_model=Summary,
44+
language_model=language_model,
45+
)(inputs)
46+
masked_summary = synalinks.OutMask(
47+
# remove the thinking field from the chain of thought
48+
mask=["thinking"],
49+
)(summary)
50+
51+
program = Program(
52+
inputs=inputs,
53+
outputs=masked_summary,
54+
name="summary_generator",
55+
description="Generate a summary from a document",
56+
)
57+
```
58+
59+
Args:
60+
mask (list): The list of keys to remove.
61+
name (str): Optional. The name of the module.
62+
description (str): Optional. The description of the module.
63+
trainable (bool): Whether the module's variables should be trainable.
64+
"""
65+
66+
def __init__(
67+
self,
68+
mask=None,
69+
name=None,
70+
description=None,
71+
trainable=False,
72+
):
73+
if not mask or not isinstance(mask, list):
74+
raise ValueError("`mask` parameter should be a list of fields to remove")
75+
super().__init__(
76+
name=name,
77+
description=description,
78+
)
79+
self.mask = mask
80+
81+
async def call(self, inputs):
82+
outputs = tree.map_structure(
83+
lambda x: x.out_mask(mask=self.mask),
84+
inputs,
85+
)
86+
return outputs

0 commit comments

Comments
 (0)