@@ -19,10 +19,11 @@ limitations under the License. */
19
19
20
20
namespace framework = paddle::framework;
21
21
22
- void do_sum (framework::ThreadPool* pool, std::atomic< int >* sum, int cnt) {
23
- std::vector<std::future< void >> fs;
22
+ void do_sum (std::vector< std::future< void >>* fs, std::mutex* mu,
23
+ std::atomic< int >* sum, int cnt) {
24
24
for (int i = 0 ; i < cnt; ++i) {
25
- fs.push_back (framework::Async ([sum]() { sum->fetch_add (1 ); }));
25
+ std::lock_guard<std::mutex> l (*mu);
26
+ fs->push_back (framework::Async ([sum]() { sum->fetch_add (1 ); }));
26
27
}
27
28
}
28
29
@@ -40,17 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
40
41
}
41
42
42
43
TEST (ThreadPool, ConcurrentRun) {
43
- framework::ThreadPool* pool = framework::ThreadPool::GetInstance ();
44
44
std::atomic<int > sum (0 );
45
45
std::vector<std::thread> threads;
46
+ std::vector<std::future<void >> fs;
47
+ std::mutex fs_mu;
46
48
int n = 50 ;
47
49
// sum = (n * (n + 1)) / 2
48
50
for (int i = 1 ; i <= n; ++i) {
49
- std::thread t (do_sum, pool , &sum, i);
51
+ std::thread t (do_sum, &fs, &fs_mu , &sum, i);
50
52
threads.push_back (std::move (t));
51
53
}
52
54
for (auto & t : threads) {
53
55
t.join ();
54
56
}
57
+ for (auto & t : fs) {
58
+ t.wait ();
59
+ }
55
60
EXPECT_EQ (sum, ((n + 1 ) * n) / 2 );
56
61
}
0 commit comments