@@ -27,7 +27,7 @@ using namespace bcos;
2727void MasterCache::addCalculatorCipher (std::string _peerId,
2828 std::map<uint32_t , bcos::bytes>&& _cipherData, uint32_t seq, uint32_t dataBatchCount)
2929{
30- bcos::WriteGuard lock (x_calculatorCipher);
30+ bcos::WriteGuard l (x_calculatorCipher);
3131 m_calculatorCipher.insert (_cipherData.begin (), _cipherData.end ());
3232 m_calculatorCipherSeqs.insert (seq);
3333 if (dataBatchCount)
@@ -36,35 +36,42 @@ void MasterCache::addCalculatorCipher(std::string _peerId,
3636 }
3737 ECDH_MULTI_LOG (INFO) << LOG_DESC (
3838 " addCalculatorCipher: master receive cipher data from calculator" )
39- << LOG_KV (" calculator" , _peerId)
40- << LOG_KV (" task" , printCacheState (m_taskState))
39+ << LOG_KV (" calculator" , _peerId) << LOG_KV (" task" , printCacheState ())
4140 << LOG_KV (" receivedSize" , m_calculatorCipherSeqs.size ())
4241 << LOG_KV (" dataBatchCount" , m_calculatorDataBatchCount);
4342 if (m_calculatorDataBatchCount > 0 &&
4443 m_calculatorCipherSeqs.size () == m_calculatorDataBatchCount)
4544 {
4645 ECDH_MULTI_LOG (INFO) << LOG_DESC (" The master receive all cipher data from the calculator" )
4746 << LOG_KV (" calculatorId" , _peerId)
48- << LOG_KV (" task" , printCacheState (m_taskState ));
47+ << LOG_KV (" task" , printCacheState ());
4948 m_finishedPartners.insert (_peerId);
5049 }
5150}
5251
5352void MasterCache::addPartnerCipher (std::string _peerId, std::vector<bcos::bytes>&& _cipherData,
54- uint32_t seq, uint32_t needSendTimes )
53+ uint32_t seq, uint32_t parternerDataCount )
5554{
5655 bcos::WriteGuard lock (x_partnerToCipher);
5756 if (!m_partnerToCipher.count (_peerId))
5857 {
59- m_partnerToCipher.insert (std::make_pair (_peerId, std::set ()));
58+ m_partnerToCipher.insert (std::make_pair (_peerId, std::set<bcos::bytes> ()));
6059 }
6160 m_partnerToCipher[_peerId].insert (_cipherData.begin (), _cipherData.end ());
6261 m_partnerCipherSeqs[_peerId].insert (seq);
6362 ECDH_MULTI_LOG (INFO) << LOG_DESC (" addPartnerCipher" ) << LOG_KV (" partner" , _peerId)
6463 << LOG_KV (" seqSize" , m_partnerCipherSeqs.at (_peerId).size ())
65- << LOG_KV (" task" , printCacheState (m_taskState));
66-
67- if (m_partnerCipherSeqs[_peerId].size () == needSendTimes)
64+ << LOG_KV (" task" , printCacheState ());
65+ if (parternerDataCount > 0 )
66+ {
67+ m_parternerDataCount.insert (std::make_pair (_peerId, parternerDataCount));
68+ }
69+ if (!m_parternerDataCount.count (_peerId))
70+ {
71+ return ;
72+ }
73+ auto expectedCount = m_parternerDataCount.at (_peerId);
74+ if (m_partnerCipherSeqs[_peerId].size () == expectedCount)
6875 {
6976 m_finishedPartners.insert (_peerId);
7077 }
@@ -77,10 +84,9 @@ bool MasterCache::tryToIntersection()
7784 {
7885 return false ;
7986 }
80- m_state = CacheState::IntersectionProgressing;
87+ m_cacheState = CacheState::IntersectionProgressing;
8188
82- ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToIntersection " )
83- << LOG_KV (" task" , printCacheState (m_taskState));
89+ ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToIntersection " ) << LOG_KV (" task" , printCacheState ());
8490 auto startT = utcSteadyTime ();
8591 // iterator the calculator cipher to obtain intersection
8692 for (auto && it : m_calculatorCipher)
@@ -95,74 +101,98 @@ bool MasterCache::tryToIntersection()
95101 break ;
96102 }
97103 }
98- if (intersection )
104+ if (insersected )
99105 {
100- m_intersecCipher.insert (std::make_pair (it.first , std::move (it.second )));
106+ m_intersecCipher.emplace_back (std::move (it.second ));
107+ m_intersecCipherIndex.emplace_back (it.first );
101108 }
102109 }
103- m_state = CacheState::Intersectioned;
110+ m_cacheState = CacheState::Intersectioned;
104111 ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToIntersection success" )
105- << LOG_KV (" task" , printCacheState (m_taskState ))
112+ << LOG_KV (" task" , printCacheState ())
106113 << LOG_KV (" timecost" , (utcSteadyTime () - startT));
107114 return true ;
108115}
109116
110- std::vector<bcos::bytes> CalculatorCache::encryptIntersection (bcos::bytes const & randomKey)
117+ std::vector<std::pair<uint64_t , bcos::bytes>> MasterCache::encryptIntersection (
118+ bcos::bytes const & randomKey)
111119{
112120 std::vector<std::pair<uint64_t , bcos::bytes>> cipherData (m_intersecCipher.size ());
113- tbb::parallel_for_each (
114- m_intersecCipher.begin (), m_intersecCipher.end (), [&](auto const & _pair) {
115- auto value = _pair.second ;
116- auto cipherValue = m_config->eccCrypto ()->ecMultiply (value, randomKey);
117- cipherData[i] = std::make_pair (_pair.first , cipherValue);
121+ tbb::parallel_for (
122+ tbb::blocked_range<size_t >(0U , m_intersecCipher.size ()), [&](auto const & range) {
123+ for (auto i = range.begin (); i < range.end (); i++)
124+ {
125+ auto cipherValue =
126+ m_config->eccCrypto ()->ecMultiply (m_intersecCipher[i], randomKey);
127+ cipherData[i] = std::make_pair (m_intersecCipherIndex[i], cipherValue);
128+ }
118129 });
119130 return cipherData;
120131}
121132
122- bcos::bytes CalculatorCache::getPlainDataByIndex (uint64_t index) {}
133+ bcos::bytes CalculatorCache::getPlainDataByIndex (uint64_t index)
134+ {
135+ uint64_t startIndex = 0 ;
136+ uint64_t endIndex = 0 ;
137+ for (auto const & it : m_plainData)
138+ {
139+ endIndex += it->size ();
140+ if (index >= startIndex && index < endIndex)
141+ {
142+ return it->getBytes ((index - startIndex));
143+ }
144+ startIndex += it->size ();
145+ }
146+ return bcos::bytes ();
147+ }
123148
124- void CalculatorCache::tryToFinalize ()
149+ bool CalculatorCache::tryToFinalize ()
125150{
126151 if (!shouldFinalize ())
127152 {
128- return ;
153+ return false ;
129154 }
130155 auto startT = utcSteadyTime ();
131156 ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToFinalize: compute intersection" )
132157 << printTaskInfo (m_taskState->task ());
133- m_state = CacheState::Finalizing;
158+ m_cacheState = CacheState::Finalizing;
134159 // find the intersection
135160 for (auto const & it : m_intersectionCipher)
136161 {
137162 if (m_masterCipher.count (it.second ))
138163 {
139- m_intersectionResult.emplace_back (getPlainDataByIndex (it.first ));
164+ auto ret = getPlainDataByIndex (it.first );
165+ if (ret.size () > 0 )
166+ {
167+ m_intersectionResult.emplace_back (ret);
168+ }
140169 }
141170 }
142- m_state = CacheState::Finalized;
171+ m_cacheState = CacheState::Finalized;
143172 ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToFinalize: compute intersection success" )
144173 << printTaskInfo (m_taskState->task ())
145- << LOG_KV (" intersectionSize" , m_intersectionResult.size ());
146- << LOG_KV (" timecost" , (utcSteadyTime () - startT));
174+ << LOG_KV (" intersectionSize" , m_intersectionResult.size ())
175+ << LOG_KV (" timecost" , (utcSteadyTime () - startT));
147176
148177 ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToFinalize: syncIntersections" )
149178 << printTaskInfo (m_taskState->task ());
150- m_state = CacheState::Syncing;
179+ m_cacheState = CacheState::Syncing;
151180 syncIntersections ();
152- m_state = CacheState::Synced;
181+ m_cacheState = CacheState::Synced;
153182
154- m_state = CacheState::StoreProgressing;
183+ m_cacheState = CacheState::StoreProgressing;
155184 m_taskState->storePSIResult (m_config->dataResourceLoader (), m_intersectionResult);
156- m_state = CacheState::Stored;
185+ m_cacheState = CacheState::Stored;
157186 ECDH_MULTI_LOG (INFO) << LOG_DESC (" tryToFinalize: syncIntersections and store success" )
158187 << printTaskInfo (m_taskState->task ());
188+ return true ;
159189}
160190
161191void CalculatorCache::syncIntersections ()
162192{
163193 ECDH_MULTI_LOG (INFO) << LOG_DESC (" syncIntersections" ) << printTaskInfo (m_taskState->task ());
164194 auto peers = m_taskState->task ()->getAllPeerParties ();
165- auto taskID = m_taskState->task ()->taskID ();
195+ auto taskID = m_taskState->task ()->id ();
166196 // notify task result
167197 if (!m_syncResult)
168198 {
@@ -178,8 +208,9 @@ void CalculatorCache::syncIntersections()
178208 if (_error && _error->errorCode () != 0 )
179209 {
180210 ECDH_MULTI_LOG (WARNING)
181- << LOG_DESC (" sync task result to peer failed" ) << LOG_KV (" peer" , peer)
182- << LOG_KV (" taskID" , taskID) << LOG_KV (" code" , _error->errorCode ())
211+ << LOG_DESC (" sync task result to peer failed" )
212+ << LOG_KV (" peer" , peer.first ) << LOG_KV (" taskID" , taskID)
213+ << LOG_KV (" code" , _error->errorCode ())
183214 << LOG_KV (" msg" , _error->errorMessage ());
184215 return ;
185216 }
@@ -197,12 +228,12 @@ void CalculatorCache::syncIntersections()
197228 for (auto & peer : peers)
198229 {
199230 m_config->generateAndSendPPCMessage (
200- _peer .first , taskID, message,
231+ peer .first , taskID, message,
201232 [taskID, peer](bcos::Error::Ptr&& _error) {
202233 if (_error && _error->errorCode () != 0 )
203234 {
204235 ECDH_MULTI_LOG (WARNING)
205- << LOG_DESC (" sync psi result to peer failed" ) << LOG_KV (" peer" , peer)
236+ << LOG_DESC (" sync psi result to peer failed" ) << LOG_KV (" peer" , peer. first )
206237 << LOG_KV (" taskID" , taskID) << LOG_KV (" code" , _error->errorCode ())
207238 << LOG_KV (" msg" , _error->errorMessage ());
208239 return ;
0 commit comments