Skip to content

Commit dcdb63a

Browse files
committed
optimize multi-psi perf when exist multi-partners
1 parent 6026a71 commit dcdb63a

File tree

5 files changed

+217
-90
lines changed

5 files changed

+217
-90
lines changed

cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiCache.cpp

Lines changed: 161 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,81 @@ using namespace bcos;
2727
void MasterCache::addCalculatorCipher(std::string _peerId,
2828
std::map<uint32_t, bcos::bytes>&& _cipherData, uint32_t seq, uint32_t dataBatchCount)
2929
{
30-
bcos::WriteGuard l(x_calculatorCipher);
31-
m_calculatorCipher.insert(_cipherData.begin(), _cipherData.end());
30+
bcos::Guard l(m_mutex);
3231
m_calculatorCipherSeqs.insert(seq);
3332
if (dataBatchCount)
3433
{
3534
m_calculatorDataBatchCount = dataBatchCount;
3635
}
37-
ECDH_MULTI_LOG(INFO) << LOG_DESC(
38-
"addCalculatorCipher: master receive cipher data from calculator")
39-
<< LOG_KV("calculator", _peerId) << printCacheState()
40-
<< LOG_KV("receivedSize", m_calculatorCipherSeqs.size())
41-
<< LOG_KV("calculatorCipherSize", m_calculatorCipher.size())
42-
<< LOG_KV("dataBatchCount", m_calculatorDataBatchCount);
36+
for (auto&& it : _cipherData)
37+
{
38+
updateMasterDataRef(_peerId, std::move(it.second), it.first);
39+
}
40+
// try to merge the
4341
if (m_calculatorDataBatchCount > 0 &&
4442
m_calculatorCipherSeqs.size() == m_calculatorDataBatchCount)
4543
{
4644
ECDH_MULTI_LOG(INFO) << LOG_DESC("The master receive all cipher data from the calculator")
47-
<< LOG_KV("calculatorId", _peerId) << printCacheState();
45+
<< LOG_KV("calculatorId", _peerId)
46+
<< LOG_KV("masterData", m_masterDataRef.size()) << printCacheState();
4847
m_finishedPartners.insert(_peerId);
48+
// try to merge
49+
mergeMasterCipher(_peerId);
4950
}
51+
ECDH_MULTI_LOG(INFO) << LOG_DESC(
52+
"addCalculatorCipher: master receive cipher data from calculator")
53+
<< LOG_KV("calculator", _peerId) << printCacheState()
54+
<< LOG_KV("receivedSize", _cipherData.size())
55+
<< LOG_KV("masterData", m_masterDataRef.size())
56+
<< LOG_KV("dataBatchCount", m_calculatorDataBatchCount);
5057
}
5158

59+
void MasterCache::updateMasterDataRef(
60+
std::string const& _peerId, bcos::bytes&& data, int32_t dataIndex)
61+
{
62+
// not merged case
63+
if (!m_peerMerged)
64+
{
65+
if (!m_masterDataRef.count(data))
66+
{
67+
MasterCipherRef ref;
68+
ref.refInfo.insert(_peerId);
69+
ref.dataIndex = dataIndex;
70+
m_masterDataRef.insert(std::make_pair(std::move(data), ref));
71+
}
72+
else
73+
{
74+
m_masterDataRef[data].refInfo.insert(_peerId);
75+
m_masterDataRef[data].dataIndex = dataIndex;
76+
}
77+
}
78+
else
79+
{
80+
// merged case, only record the intersection case
81+
if (m_masterDataRef.count(data))
82+
{
83+
m_masterDataRef[data].refInfo.insert(_peerId);
84+
m_masterDataRef[data].dataIndex = dataIndex;
85+
}
86+
}
87+
}
88+
89+
5290
void MasterCache::addPartnerCipher(std::string _peerId, std::vector<bcos::bytes>&& _cipherData,
5391
uint32_t seq, uint32_t parternerDataCount)
5492
{
55-
bcos::WriteGuard lock(x_partnerToCipher);
56-
if (!m_partnerToCipher.count(_peerId))
93+
bcos::Guard lock(m_mutex);
94+
// record the data-ref-count
95+
for (auto&& data : _cipherData)
5796
{
58-
m_partnerToCipher.insert(std::make_pair(_peerId, std::set<bcos::bytes>()));
97+
updateMasterDataRef(_peerId, std::move(data), -1);
5998
}
60-
m_partnerToCipher[_peerId].insert(_cipherData.begin(), _cipherData.end());
6199
m_partnerCipherSeqs[_peerId].insert(seq);
62100
ECDH_MULTI_LOG(INFO) << LOG_DESC("addPartnerCipher") << LOG_KV("partner", _peerId)
63101
<< LOG_KV("seqSize", m_partnerCipherSeqs.at(_peerId).size())
64-
<< LOG_KV("cipherDataSize", _cipherData.size()) << printCacheState();
102+
<< LOG_KV("cipherDataSize", _cipherData.size())
103+
<< LOG_KV("distinct-partnerDataSize", m_masterDataRef.size())
104+
<< LOG_KV("parternerDataCount", parternerDataCount) << printCacheState();
65105
if (parternerDataCount > 0)
66106
{
67107
m_parternerDataCount.insert(std::make_pair(_peerId, parternerDataCount));
@@ -74,9 +114,43 @@ void MasterCache::addPartnerCipher(std::string _peerId, std::vector<bcos::bytes>
74114
if (m_partnerCipherSeqs[_peerId].size() == expectedCount)
75115
{
76116
m_finishedPartners.insert(_peerId);
117+
// merge once a peer has finished sending all of its data
118+
mergeMasterCipher(_peerId);
77119
}
78120
}
79121

122+
void MasterCache::mergeMasterCipher(std::string const& peer)
123+
{
124+
if (m_peerMerged)
125+
{
126+
return;
127+
}
128+
// no need to merge when partnerCount is 1
129+
if (m_peerCount == 1)
130+
{
131+
return;
132+
}
133+
ECDH_MULTI_LOG(INFO) << LOG_DESC("Receive whole data from peer, mergeMasterCipher")
134+
<< LOG_KV("distinct-partnerDataSize-before-merge", m_masterDataRef.size())
135+
<< LOG_KV("finishedPeer", peer) << LOG_KV("partnerCount", m_peerCount);
136+
auto startT = utcSteadyTime();
137+
for (auto it = m_masterDataRef.begin(); it != m_masterDataRef.end();)
138+
{
139+
// not has intersect-element with the finished peer
140+
if (!it->second.refInfo.count(peer))
141+
{
142+
it = m_masterDataRef.erase(it);
143+
continue;
144+
}
145+
it++;
146+
}
147+
m_peerMerged = true;
148+
ECDH_MULTI_LOG(INFO) << LOG_DESC("mergeMasterCipher finished")
149+
<< LOG_KV("distinct-partnerDataSize-after-merge", m_masterDataRef.size())
150+
<< LOG_KV("finishedPeer", peer)
151+
<< LOG_KV("timecost", (utcSteadyTime() - startT));
152+
}
153+
80154
// get the cipher-data intersection: h(x)^a && h(Y)^a
81155
bool MasterCache::tryToIntersection()
82156
{
@@ -87,29 +161,31 @@ bool MasterCache::tryToIntersection()
87161
m_cacheState = CacheState::IntersectionProgressing;
88162

89163
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToIntersection ") << printCacheState()
90-
<< LOG_KV("calculatorCipher", m_calculatorCipher.size());
164+
<< LOG_KV("masterData", m_masterDataRef.size());
91165
auto startT = utcSteadyTime();
92166
// iterate over the calculator cipher to obtain the intersection
93-
for (auto&& it : m_calculatorCipher)
167+
for (auto&& it : m_masterDataRef)
94168
{
95-
bool insersected = true;
96-
for (auto const& partnerIter : m_partnerToCipher)
169+
if (!m_masterDataRef.count(it.first))
97170
{
98-
// not the intersection case
99-
if (!partnerIter.second.count(it.second))
100-
{
101-
insersected = false;
102-
break;
103-
}
171+
continue;
172+
}
173+
if (m_masterDataRef.at(it.first).refInfo.size() != m_peerCount)
174+
{
175+
continue;
104176
}
105-
if (insersected)
177+
if (m_masterDataRef.at(it.first).dataIndex == -1)
106178
{
107-
m_intersecCipher.emplace_back(std::move(it.second));
108-
m_intersecCipherIndex.emplace_back(it.first);
179+
continue;
109180
}
181+
// intersection case
182+
m_intersecCipher.emplace_back(std::move(it.first));
183+
m_intersecCipherIndex.emplace_back(it.second.dataIndex);
110184
}
185+
releaseCache();
111186
m_cacheState = CacheState::Intersectioned;
112187
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToIntersection success") << printCacheState()
188+
<< LOG_KV("intersectionSize", m_intersecCipher.size())
113189
<< LOG_KV("timecost", (utcSteadyTime() - startT));
114190
return true;
115191
}
@@ -128,7 +204,9 @@ std::vector<std::pair<uint64_t, bcos::bytes>> MasterCache::encryptIntersection(
128204
}
129205
});
130206
// Note: release m_intersecCipher; make sure it is not used after being released
131-
releaseAll();
207+
releaseItersection();
208+
ECDH_MULTI_LOG(INFO) << LOG_DESC("encryptIntersection")
209+
<< LOG_KV("cipherCount", cipherData.size()) << printCacheState();
132210
return cipherData;
133211
}
134212

@@ -155,27 +233,28 @@ bool CalculatorCache::tryToFinalize()
155233
return false;
156234
}
157235
auto startT = utcSteadyTime();
158-
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: compute intersection") << printCacheState();
236+
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: compute intersection")
237+
<< LOG_KV("cipherRef", m_cipherRef.size()) << printCacheState();
159238
m_cacheState = CacheState::Finalizing;
160239
// find the intersection
161-
for (auto const& it : m_intersectionCipher)
240+
for (auto const& it : m_cipherRef)
162241
{
163-
if (m_masterCipher.count(it.second))
242+
if (it.second.refCount < 2)
164243
{
165-
auto ret = getPlainDataByIndex(it.first);
166-
if (ret.size() > 0)
167-
{
168-
m_intersectionResult.emplace_back(ret);
169-
}
244+
continue;
245+
}
246+
if (it.second.plainDataIndex > 0)
247+
{
248+
m_intersectionResult.emplace_back(getPlainDataByIndex(it.second.plainDataIndex));
170249
}
171250
}
172251
m_cacheState = CacheState::Finalized;
173-
releaseDataAfterFinalize();
174252
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: compute intersection success")
175-
<< printCacheState()
253+
<< printCacheState() << LOG_KV("cipherRef", m_cipherRef.size())
176254
<< LOG_KV("intersectionSize", m_intersectionResult.size())
177255
<< LOG_KV("timecost", (utcSteadyTime() - startT));
178256

257+
releaseDataAfterFinalize();
179258
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: syncIntersections") << printCacheState();
180259
m_cacheState = CacheState::Syncing;
181260
syncIntersections();
@@ -244,27 +323,63 @@ void CalculatorCache::syncIntersections()
244323
}
245324
}
246325

326+
/**
 * Count one more reference to the cipher `data`.
 *
 * A cipher seen for the first time gets an entry with refCount 1; a cipher that
 * is already present has its refCount incremented. Once either the master
 * ciphers or the intersection ciphers have been received completely
 * (m_receiveAllMasterCipher / m_receiveIntersection), unknown ciphers are
 * ignored and only existing entries are updated.
 *
 * @param data  the cipher; moved into the table only when a new entry is created
 * @param index plain-data index belonging to the cipher, or -1 when the sender carries
 *              none (appendMasterCipher passes -1)
 *
 * Note: the caller is expected to hold m_mutex (both visible callers take it first).
 */
void CalculatorCache::updateCipherRef(bcos::bytes&& data, int32_t index)
{
    auto it = m_cipherRef.find(data);
    if (it == m_cipherRef.end())
    {
        // case that receive at least one completed data, only record the intersection data
        if (m_receiveAllMasterCipher || m_receiveIntersection)
        {
            return;
        }
        // new data case
        CipherRefDetail cipherRef;
        cipherRef.refCount = 1;
        cipherRef.plainDataIndex = index;
        m_cipherRef.insert(std::make_pair(std::move(data), std::move(cipherRef)));
        return;
    }
    // existing data case: this is the branch that lets refCount reach 2
    it->second.refCount += 1;
    // a sender without an index (-1) must not erase an index recorded earlier
    if (index != -1)
    {
        it->second.plainDataIndex = index;
    }
}
347+
247348

248349
bool CalculatorCache::appendMasterCipher(
249350
std::vector<bcos::bytes>&& _cipherData, uint32_t seq, uint32_t dataBatchSize)
250351
{
251-
bcos::WriteGuard lock(x_masterCipher);
252-
m_masterCipher.insert(_cipherData.begin(), _cipherData.end());
352+
bcos::Guard lock(m_mutex);
253353
m_receivedMasterCipher.insert(seq);
354+
if (!m_receiveAllMasterCipher && m_receivedMasterCipher.size() == m_masterDataBatchSize)
355+
{
356+
m_receiveAllMasterCipher = true;
357+
}
358+
for (auto&& it : _cipherData)
359+
{
360+
updateCipherRef(std::move(it), -1);
361+
}
254362
if (m_masterDataBatchSize == 0 && dataBatchSize > 0)
255363
{
256364
m_masterDataBatchSize = dataBatchSize;
257365
}
258366
ECDH_MULTI_LOG(INFO) << LOG_DESC("appendMasterCipher") << LOG_KV("dataSize", _cipherData.size())
259-
<< printCacheState();
260-
return m_receivedMasterCipher.size() == m_masterDataBatchSize;
367+
<< LOG_KV("cipherRefSize", m_cipherRef.size()) << printCacheState();
368+
369+
return m_receiveAllMasterCipher;
261370
}
262371

263372
/**
 * Record the intersection cipher data.
 *
 * Every cipher is passed to updateCipherRef() together with the key it had in
 * _cipherData, which is stored as the plain-data index of the entry.
 *
 * @param _cipherData the ciphers; the mapped values are moved out of the map
 */
void CalculatorCache::setIntersectionCipher(std::map<uint32_t, bcos::bytes>&& _cipherData)
{
    // take the lock first: m_cipherRef is read by the log statement below
    bcos::Guard lock(m_mutex);
    ECDH_MULTI_LOG(INFO) << LOG_DESC("setIntersectionCipher")
                         << LOG_KV("dataSize", _cipherData.size())
                         << LOG_KV("cipherRefSize", m_cipherRef.size()) << printCacheState();
    for (auto&& it : _cipherData)
    {
        updateCipherRef(std::move(it.second), it.first);
    }
    // Mark the intersection as received only AFTER it has been recorded: with
    // the flag set updateCipherRef() ignores ciphers it has not seen yet, so
    // setting it earlier drops every entry that arrives before the master data.
    m_receiveIntersection = true;
    ECDH_MULTI_LOG(INFO) << LOG_DESC("setIntersectionCipher finshed")
                         << LOG_KV("cipherRefSize", m_cipherRef.size()) << printCacheState();
}

0 commit comments

Comments
 (0)