Skip to content

Commit 91cf991

Browse files
cat concatenates ReplicatedSharedTensor (#326)
* `cat` concatenates `ReplicatedSharedTensor` - refer to issue #324 - `cat` can concatenates `ShareTensor`s, but it is not workable for `ReplicatedSharedTensor`. - added `cat_replicatedShare_tensor` in *apy.py* and *static.py* for the concatenation of `ReplicatedSharedTensor`. * Saving a variable and formatting Format The following commands have been run: - `python -m black tensor/static.py` - `python -m black api.py` - `isort tensor/static.py` Save number of replicated shares as a variable - Q: Could you save the len(shares[0].shares) in a variable and then use it here and on line 140? [x]*done* Amended docstring - `shares` is a tuple of `ReplicatedSharedTensor`s - the empty blank is deleted in `cat_replicatedShare_tensor` * For Python3.7 that doesn't support `math.prod` - python 3.7 cannot support `math.prod`, only python 3.8+ can. * Remove python3.7 * Remove tutorial python3.7 Co-authored-by: George Muraru <murarugeorgec@gmail.com>
1 parent 41da2eb commit 91cf991

File tree

5 files changed

+47
-9
lines changed

5 files changed

+47
-9
lines changed

.github/workflows/tutorials.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
strategy:
2323
max-parallel: 3
2424
matrix:
25-
python-version: [3.7, 3.8, 3.9]
25+
python-version: [3.8, 3.9]
2626

2727
steps:
2828
- uses: actions/checkout@v2

CONTRIBUTING.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ If you are new to the project and want to get into the code, we recommend pickin
6565
Before you get started you will need a few things installed depending on your operating system.
6666

6767
- OS Package Manager
68-
- Python 3.7+
68+
- Python 3.8+
6969
- git
7070

7171
### OSes
@@ -117,7 +117,7 @@ $ brew install git
117117

118118
## Python Versions
119119

120-
This project supports Python 3.7+, however, if you are contributing it can help to be able to switch between python versions to fix issues or bugs that relate to a specific python version. Depending on your operating system there are a number of ways to install different versions of python however one of the easiest is with the `pyenv` tool. Additionally, as we will be frequently be installing and changing python packages for this project we should isolate it from your system python and other projects you have using a virtualenv.
120+
This project supports Python 3.8+, however, if you are contributing it can help to be able to switch between python versions to fix issues or bugs that relate to a specific python version. Depending on your operating system there are a number of ways to install different versions of python however one of the easiest is with the `pyenv` tool. Additionally, as we will be frequently be installing and changing python packages for this project we should isolate it from your system python and other projects you have using a virtualenv.
121121

122122
### MacOS
123123

@@ -152,10 +152,10 @@ $ pyenv install --list | grep 3.9
152152
3.9.4
153153
```
154154

155-
Wow, there are lots of options, lets install 3.7.
155+
Wow, there are lots of options, lets install 3.8.
156156

157157
```
158-
$ pyenv install 3.7.9
158+
$ pyenv install 3.8.0
159159
```
160160

161161
Now, lets see what versions are installed:
@@ -461,7 +461,7 @@ $ pydocstyle .
461461

462462
### Imports Formatting
463463

464-
We use isort to automatically format the python imports.
464+
We use isort to automatically format the python imports.
465465
Run isort manually like this:
466466

467467
```

src/sympc/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@
262262
"sympc.tensor.replicatedshare_tensor.ReplicatedSharedTensor.repeat",
263263
"sympc.tensor.replicatedshare_tensor.ReplicatedSharedTensor",
264264
),
265+
(
266+
"sympc.tensor.static.cat_replicatedShare_tensor",
267+
"sympc.tensor.replicatedshare_tensor.ReplicatedSharedTensor",
268+
),
265269
]
266270

267271
allowed_external_attrs = [

src/sympc/protocol/falcon/falcon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def private_compare(x: List[MPCTensor], r: torch.Tensor) -> MPCTensor:
545545
c[i] = u[i] + 1 + w
546546
w += x[i] ^ r_i
547547

548-
d = m * math.prod(c)
548+
d = m * math.prod(c)
549549

550550
d_val = d.reconstruct(decode=False) # plaintext d.
551551
d_val[d_val != 0] = 1 # making all non zero values as 1.

src/sympc/tensor/static.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
import numpy as np
1616
import torch
1717

18+
import sympc.protocol as protocol
1819
from sympc.session import get_session
1920
from sympc.tensor.mpc_tensor import MPCTensor
21+
from sympc.tensor.replicatedshare_tensor import ReplicatedSharedTensor
2022
from sympc.tensor.share_tensor import ShareTensor
2123
from sympc.utils import parallel_execution
2224

@@ -90,8 +92,12 @@ def cat(tensors: List, dim: int = 0) -> MPCTensor:
9092
)
9193
)
9294

93-
stack_shares = parallel_execution(cat_share_tensor, session.parties)(args)
94-
from sympc.tensor import MPCTensor
95+
if isinstance(session.protocol, protocol.FSS):
96+
stack_shares = parallel_execution(cat_share_tensor, session.parties)(args)
97+
elif isinstance(session.protocol, protocol.Falcon):
98+
stack_shares = parallel_execution(cat_replicatedShare_tensor, session.parties)(
99+
args
100+
)
95101

96102
expected_shape = torch.cat(
97103
[torch.empty(each_tensor.shape) for each_tensor in tensors], dim=dim
@@ -118,6 +124,34 @@ def cat_share_tensor(session_uuid_str: str, *shares: Tuple[ShareTensor]) -> Shar
118124
return result
119125

120126

127+
def cat_replicatedShare_tensor(
128+
session_uuid_str: str, *shares: Tuple[ReplicatedSharedTensor]
129+
) -> ReplicatedSharedTensor:
130+
"""Helper method that performs torch.cat on the replicated shares of the Tensors.
131+
132+
Args:
133+
session_uuid_str (str): UUID to identify the session on each party side.
134+
shares (Tuple[ReplicatedSharedTensor]): Replicated shares of the tensors to be concatenated.
135+
136+
Returns:
137+
ReplicatedSharedTensor: Respective replicated shares after concatenation
138+
"""
139+
session = get_session(session_uuid_str)
140+
result = ReplicatedSharedTensor(
141+
session_uuid=UUID(session_uuid_str), config=session.config
142+
)
143+
144+
num_shares = len(shares[0].shares)
145+
146+
cat_result = [torch.tensor([]).type(torch.LongTensor) for _ in range(num_shares)]
147+
for share in shares:
148+
for i in range(num_shares):
149+
cat_result[i] = torch.cat([cat_result[i], share.shares[i]])
150+
151+
result.shares = cat_result
152+
return result
153+
154+
121155
def helper_argmax(
122156
x: MPCTensor,
123157
dim: Optional[Union[int, Tuple[int]]] = None,

0 commit comments

Comments
 (0)