Skip to content

Commit dcae000

Browse files
Chaluvadisyurkevi
authored andcommitted
Added monte_carlo_pi benchmark example
1 parent f44b835 commit dcae000

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

examples/benchmarks/monte_carlo_pi.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/python
2+
3+
#######################################################
4+
# Copyright (c) 2024, ArrayFire
5+
# All rights reserved.
6+
#
7+
# This file is distributed under 3-clause BSD license.
8+
# The complete license agreement can be obtained at:
9+
# http://arrayfire.com/licenses/BSD-3-Clause
10+
########################################################
11+
12+
from random import random
13+
from time import time
14+
import arrayfire as af
15+
import sys
16+
17+
try:
18+
import numpy as np
19+
except ImportError:
20+
np = None
21+
22+
#alias range / xrange because xrange is faster than range in python2
23+
try:
24+
frange = xrange #Python2
25+
except NameError:
26+
frange = range #Python3
27+
28+
# Having the function outside is faster than the lambda inside
29+
def in_circle(x, y):
30+
return (x*x + y*y) < 1
31+
32+
def calc_pi_device(samples):
33+
x = af.randu((samples,))
34+
y = af.randu((samples,))
35+
return 4 * af.sum(in_circle(x, y)) / samples
36+
37+
def calc_pi_numpy(samples):
38+
np.random.seed(1)
39+
x = np.random.rand(samples).astype(np.float32)
40+
y = np.random.rand(samples).astype(np.float32)
41+
return 4. * np.sum(in_circle(x, y)) / samples
42+
43+
def calc_pi_host(samples):
44+
count = sum(1 for k in frange(samples) if in_circle(random(), random()))
45+
return 4 * float(count) / samples
46+
47+
def bench(calc_pi, samples=1000000, iters=25):
48+
func_name = calc_pi.__name__[8:]
49+
print("Monte carlo estimate of pi on %s with %d million samples: %f" % \
50+
(func_name, samples/1e6, calc_pi(samples)))
51+
52+
start = time()
53+
for k in frange(iters):
54+
calc_pi(samples)
55+
end = time()
56+
57+
print("Average time taken: %f ms" % (1000 * (end - start) / iters))
58+
59+
if __name__ == "__main__":
60+
if (len(sys.argv) > 1):
61+
af.set_device(int(sys.argv[1]))
62+
af.info()
63+
64+
bench(calc_pi_device)
65+
if np:
66+
bench(calc_pi_numpy)
67+
bench(calc_pi_host)

0 commit comments

Comments
 (0)