Skip to content

Commit 7c116db

Browse files
committed
optimize multi-psi performance when multiple partners exist
1 parent 6026a71 commit 7c116db

File tree

5 files changed

+237
-91
lines changed

5 files changed

+237
-91
lines changed

cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiCache.cpp

Lines changed: 163 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,80 @@ using namespace bcos;
2727
void MasterCache::addCalculatorCipher(std::string _peerId,
2828
std::map<uint32_t, bcos::bytes>&& _cipherData, uint32_t seq, uint32_t dataBatchCount)
2929
{
30-
bcos::WriteGuard l(x_calculatorCipher);
31-
m_calculatorCipher.insert(_cipherData.begin(), _cipherData.end());
30+
bcos::Guard l(m_mutex);
3231
m_calculatorCipherSeqs.insert(seq);
3332
if (dataBatchCount)
3433
{
3534
m_calculatorDataBatchCount = dataBatchCount;
3635
}
37-
ECDH_MULTI_LOG(INFO) << LOG_DESC(
38-
"addCalculatorCipher: master receive cipher data from calculator")
39-
<< LOG_KV("calculator", _peerId) << printCacheState()
40-
<< LOG_KV("receivedSize", m_calculatorCipherSeqs.size())
41-
<< LOG_KV("calculatorCipherSize", m_calculatorCipher.size())
42-
<< LOG_KV("dataBatchCount", m_calculatorDataBatchCount);
36+
for (auto&& it : _cipherData)
37+
{
38+
updateMasterDataRef(_peerId, std::move(it.second), it.first);
39+
}
40+
// try to merge the
4341
if (m_calculatorDataBatchCount > 0 &&
4442
m_calculatorCipherSeqs.size() == m_calculatorDataBatchCount)
4543
{
4644
ECDH_MULTI_LOG(INFO) << LOG_DESC("The master receive all cipher data from the calculator")
47-
<< LOG_KV("calculatorId", _peerId) << printCacheState();
45+
<< LOG_KV("calculatorId", _peerId)
46+
<< LOG_KV("masterData", m_masterDataRef.size()) << printCacheState();
4847
m_finishedPartners.insert(_peerId);
48+
// try to merge
49+
mergeMasterCipher(_peerId);
50+
}
51+
ECDH_MULTI_LOG(INFO) << LOG_DESC(
52+
"addCalculatorCipher: master receive cipher data from calculator")
53+
<< LOG_KV("calculator", _peerId) << printCacheState()
54+
<< LOG_KV("receivedSize", _cipherData.size())
55+
<< LOG_KV("masterData", m_masterDataRef.size())
56+
<< LOG_KV("dataBatchCount", m_calculatorDataBatchCount);
57+
}
58+
59+
void MasterCache::updateMasterDataRef(
60+
std::string const& _peerId, bcos::bytes&& data, int32_t dataIndex)
61+
{
62+
// not merged case
63+
if (!m_peerMerged)
64+
{
65+
// new data case
66+
if (!m_masterDataRef.count(data))
67+
{
68+
MasterCipherRef ref;
69+
ref.refInfo.insert(_peerId);
70+
ref.updateDataIndex(dataIndex);
71+
m_masterDataRef.insert(std::make_pair(std::move(data), ref));
72+
return;
73+
}
74+
// existed data case
75+
m_masterDataRef[data].refInfo.insert(_peerId);
76+
m_masterDataRef[data].updateDataIndex(dataIndex);
77+
return;
78+
}
79+
80+
// merged case, only record the intersection case
81+
if (m_masterDataRef.count(data))
82+
{
83+
m_masterDataRef[data].refInfo.insert(_peerId);
84+
m_masterDataRef[data].updateDataIndex(dataIndex);
4985
}
5086
}
5187

88+
5289
void MasterCache::addPartnerCipher(std::string _peerId, std::vector<bcos::bytes>&& _cipherData,
5390
uint32_t seq, uint32_t parternerDataCount)
5491
{
55-
bcos::WriteGuard lock(x_partnerToCipher);
56-
if (!m_partnerToCipher.count(_peerId))
92+
bcos::Guard lock(m_mutex);
93+
// record the data-ref-count
94+
for (auto&& data : _cipherData)
5795
{
58-
m_partnerToCipher.insert(std::make_pair(_peerId, std::set<bcos::bytes>()));
96+
updateMasterDataRef(_peerId, std::move(data), -1);
5997
}
60-
m_partnerToCipher[_peerId].insert(_cipherData.begin(), _cipherData.end());
6198
m_partnerCipherSeqs[_peerId].insert(seq);
6299
ECDH_MULTI_LOG(INFO) << LOG_DESC("addPartnerCipher") << LOG_KV("partner", _peerId)
63100
<< LOG_KV("seqSize", m_partnerCipherSeqs.at(_peerId).size())
64-
<< LOG_KV("cipherDataSize", _cipherData.size()) << printCacheState();
101+
<< LOG_KV("cipherDataSize", _cipherData.size())
102+
<< LOG_KV("distinct-partnerDataSize", m_masterDataRef.size())
103+
<< LOG_KV("parternerDataCount", parternerDataCount) << printCacheState();
65104
if (parternerDataCount > 0)
66105
{
67106
m_parternerDataCount.insert(std::make_pair(_peerId, parternerDataCount));
@@ -74,9 +113,43 @@ void MasterCache::addPartnerCipher(std::string _peerId, std::vector<bcos::bytes>
74113
if (m_partnerCipherSeqs[_peerId].size() == expectedCount)
75114
{
76115
m_finishedPartners.insert(_peerId);
116+
// merge once a peer has finished sending all of its data
117+
mergeMasterCipher(_peerId);
77118
}
78119
}
79120

121+
void MasterCache::mergeMasterCipher(std::string const& peer)
122+
{
123+
if (m_peerMerged)
124+
{
125+
return;
126+
}
127+
// no need to merge when partnerCount is 1
128+
if (m_peerCount == 1)
129+
{
130+
return;
131+
}
132+
ECDH_MULTI_LOG(INFO) << LOG_DESC("Receive whole data from peer, mergeMasterCipher")
133+
<< LOG_KV("distinct-partnerDataSize-before-merge", m_masterDataRef.size())
134+
<< LOG_KV("finishedPeer", peer) << LOG_KV("partnerCount", m_peerCount);
135+
auto startT = utcSteadyTime();
136+
for (auto it = m_masterDataRef.begin(); it != m_masterDataRef.end();)
137+
{
138+
// not has intersect-element with the finished peer
139+
if (!it->second.refInfo.count(peer))
140+
{
141+
it = m_masterDataRef.erase(it);
142+
continue;
143+
}
144+
it++;
145+
}
146+
m_peerMerged = true;
147+
ECDH_MULTI_LOG(INFO) << LOG_DESC("mergeMasterCipher finished")
148+
<< LOG_KV("distinct-partnerDataSize-after-merge", m_masterDataRef.size())
149+
<< LOG_KV("finishedPeer", peer)
150+
<< LOG_KV("timecost", (utcSteadyTime() - startT));
151+
}
152+
80153
// get the cipher-data intersection: h(x)^a && h(Y)^a
81154
bool MasterCache::tryToIntersection()
82155
{
@@ -87,29 +160,31 @@ bool MasterCache::tryToIntersection()
87160
m_cacheState = CacheState::IntersectionProgressing;
88161

89162
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToIntersection ") << printCacheState()
90-
<< LOG_KV("calculatorCipher", m_calculatorCipher.size());
163+
<< LOG_KV("masterData", m_masterDataRef.size());
91164
auto startT = utcSteadyTime();
92-
// iterator the calculator cipher to obtain intersection
93-
for (auto&& it : m_calculatorCipher)
165+
// iterate over the masterDataRef to obtain the intersection
166+
for (auto&& it : m_masterDataRef)
94167
{
95-
bool insersected = true;
96-
for (auto const& partnerIter : m_partnerToCipher)
168+
if (!m_masterDataRef.count(it.first))
97169
{
98-
// not the intersection case
99-
if (!partnerIter.second.count(it.second))
100-
{
101-
insersected = false;
102-
break;
103-
}
170+
continue;
104171
}
105-
if (insersected)
172+
if (m_masterDataRef.at(it.first).refInfo.size() != m_peerCount)
106173
{
107-
m_intersecCipher.emplace_back(std::move(it.second));
108-
m_intersecCipherIndex.emplace_back(it.first);
174+
continue;
109175
}
176+
if (m_masterDataRef.at(it.first).dataIndex == -1)
177+
{
178+
continue;
179+
}
180+
// intersection case
181+
m_intersecCipher.emplace_back(std::move(it.first));
182+
m_intersecCipherIndex.emplace_back(it.second.dataIndex);
110183
}
184+
releaseCache();
111185
m_cacheState = CacheState::Intersectioned;
112186
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToIntersection success") << printCacheState()
187+
<< LOG_KV("intersectionSize", m_intersecCipher.size())
113188
<< LOG_KV("timecost", (utcSteadyTime() - startT));
114189
return true;
115190
}
@@ -128,7 +203,9 @@ std::vector<std::pair<uint64_t, bcos::bytes>> MasterCache::encryptIntersection(
128203
}
129204
});
130205
// Note: release the m_intersecCipher; make sure it is not used after being released
131-
releaseAll();
206+
releaseItersection();
207+
ECDH_MULTI_LOG(INFO) << LOG_DESC("encryptIntersection")
208+
<< LOG_KV("cipherCount", cipherData.size()) << printCacheState();
132209
return cipherData;
133210
}
134211

@@ -155,27 +232,28 @@ bool CalculatorCache::tryToFinalize()
155232
return false;
156233
}
157234
auto startT = utcSteadyTime();
158-
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: compute intersection") << printCacheState();
235+
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: compute intersection")
236+
<< LOG_KV("cipherRef", m_cipherRef.size()) << printCacheState();
159237
m_cacheState = CacheState::Finalizing;
160238
// find the intersection
161-
for (auto const& it : m_intersectionCipher)
239+
for (auto const& it : m_cipherRef)
162240
{
163-
if (m_masterCipher.count(it.second))
241+
if (it.second.refCount < 2)
164242
{
165-
auto ret = getPlainDataByIndex(it.first);
166-
if (ret.size() > 0)
167-
{
168-
m_intersectionResult.emplace_back(ret);
169-
}
243+
continue;
244+
}
245+
if (it.second.plainDataIndex > 0)
246+
{
247+
m_intersectionResult.emplace_back(getPlainDataByIndex(it.second.plainDataIndex));
170248
}
171249
}
172250
m_cacheState = CacheState::Finalized;
173-
releaseDataAfterFinalize();
174251
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: compute intersection success")
175-
<< printCacheState()
252+
<< printCacheState() << LOG_KV("cipherRef", m_cipherRef.size())
176253
<< LOG_KV("intersectionSize", m_intersectionResult.size())
177254
<< LOG_KV("timecost", (utcSteadyTime() - startT));
178255

256+
releaseDataAfterFinalize();
179257
ECDH_MULTI_LOG(INFO) << LOG_DESC("tryToFinalize: syncIntersections") << printCacheState();
180258
m_cacheState = CacheState::Syncing;
181259
syncIntersections();
@@ -244,27 +322,65 @@ void CalculatorCache::syncIntersections()
244322
}
245323
}
246324

325+
/// Count one occurrence of the cipher `data` in m_cipherRef.
///
/// - If `data` is already tracked, its refCount is incremented and its plain
///   index is updated via updatePlainIndex().
/// - If `data` is unknown and neither m_receiveAllMasterCipher nor
///   m_receiveIntersection is set, a new entry with refCount 1 is created.
/// - If `data` is unknown and one of those flags is set, it is ignored: once one
///   complete data set has been received, only elements already tracked are recorded.
///
/// The key is looked up exactly once; the previous code called count() up to
/// twice and operator[] twice on the byte-string key for every element.
///
/// @param data   cipher bytes; moved into the table only when a new entry is created
/// @param index  index forwarded to CipherRefDetail::updatePlainIndex
void CalculatorCache::updateCipherRef(bcos::bytes&& data, int32_t index)
{
    auto it = m_cipherRef.find(data);
    // existed data case
    if (it != m_cipherRef.end())
    {
        it->second.refCount += 1;
        it->second.updatePlainIndex(index);
        return;
    }
    // case that receive at least one completed data, only record the intersection data
    if (m_receiveAllMasterCipher || m_receiveIntersection)
    {
        return;
    }
    // new data case
    CipherRefDetail cipherRef;
    cipherRef.refCount = 1;
    cipherRef.updatePlainIndex(index);
    m_cipherRef.insert(std::make_pair(std::move(data), std::move(cipherRef)));
}
348+
247349

248350
bool CalculatorCache::appendMasterCipher(
249351
std::vector<bcos::bytes>&& _cipherData, uint32_t seq, uint32_t dataBatchSize)
250352
{
251-
bcos::WriteGuard lock(x_masterCipher);
252-
m_masterCipher.insert(_cipherData.begin(), _cipherData.end());
353+
bcos::Guard lock(m_mutex);
253354
m_receivedMasterCipher.insert(seq);
254355
if (m_masterDataBatchSize == 0 && dataBatchSize > 0)
255356
{
256357
m_masterDataBatchSize = dataBatchSize;
257358
}
359+
if (!m_receiveAllMasterCipher && m_receivedMasterCipher.size() == m_masterDataBatchSize)
360+
{
361+
m_receiveAllMasterCipher = true;
362+
}
363+
for (auto&& it : _cipherData)
364+
{
365+
updateCipherRef(std::move(it), -1);
366+
}
258367
ECDH_MULTI_LOG(INFO) << LOG_DESC("appendMasterCipher") << LOG_KV("dataSize", _cipherData.size())
259-
<< printCacheState();
260-
return m_receivedMasterCipher.size() == m_masterDataBatchSize;
368+
<< LOG_KV("cipherRefSize", m_cipherRef.size()) << printCacheState();
369+
370+
return m_receiveAllMasterCipher;
261371
}
262372

263373
/// Record the intersection cipher data received by the calculator.
///
/// Every (index, cipher) entry is passed to updateCipherRef, then
/// m_receiveIntersection is set.
///
/// @param _cipherData  index -> cipher entries; the cipher bytes are moved out
void CalculatorCache::setIntersectionCipher(std::map<uint32_t, bcos::bytes>&& _cipherData)
{
    // Take the lock before the first log line: it reads m_cipherRef.size(), which
    // other methods modify while holding m_mutex.
    bcos::Guard lock(m_mutex);
    ECDH_MULTI_LOG(INFO) << LOG_DESC("setIntersectionCipher")
                         << LOG_KV("dataSize", _cipherData.size())
                         << LOG_KV("cipherRefSize", m_cipherRef.size()) << printCacheState();
    for (auto&& it : _cipherData)
    {
        updateCipherRef(std::move(it.second), it.first);
    }
    // Raise the flag only AFTER the data has been recorded. updateCipherRef
    // ignores unknown elements as soon as m_receiveIntersection is set, so raising
    // it first would discard the whole intersection whenever it arrives before
    // the master cipher.
    m_receiveIntersection = true;
    ECDH_MULTI_LOG(INFO) << LOG_DESC("setIntersectionCipher finshed")
                         << LOG_KV("cipherRefSize", m_cipherRef.size()) << printCacheState();
}

0 commit comments

Comments
 (0)