@@ -10,7 +10,7 @@ class Matrix(object):
1010 """Object representation of the item-item matrix
1111 """
1212
13- def __init__ (self , data , combinfunc , symmetric = False , diagonal = None , num_processes = 1 ):
13+ def __init__ (self , data , combinfunc , symmetric = False , diagonal = None ):
1414 """Takes a list of data and generates a 2D-matrix using the supplied
1515 combination function to calculate the values.
1616
@@ -28,21 +28,11 @@ def __init__(self, data, combinfunc, symmetric=False, diagonal=None, num_process
2828 could be the function "x-y". Then each diagonal cell
2929 will be "0". If this value is set to None, then the
3030 diagonal will be calculated. Default: None
31- num_processes
32- - If you want to use multiprocessing to split up the work
33- and run combinfunc() in parallel, specify num_processes
34- > 1 and this number of workers will be spun up, the work
35- split up amongst them evenly. Default: 1
3631 """
3732 self .data = data
3833 self .combinfunc = combinfunc
3934 self .symmetric = symmetric
4035 self .diagonal = diagonal
41- self .num_processes = num_processes
42- self .use_multiprocessing = num_processes > 1
43- if self .use_multiprocessing :
44- self .task_queue = Queue ()
45- self .done_queue = Queue ()
4636
4737 def worker (self ):
4838 """Multiprocessing task function run by worker processes
@@ -59,15 +49,29 @@ def worker(self):
5949 current_process ().name ,
6050 tasks_completed )
6151
62- def genmatrix (self ):
52+ def genmatrix (self , num_processes = 1 ):
53+ """Actually generate the matrix
54+
55+ PARAMETERS
56+ num_processes
57+ - If you want to use multiprocessing to split up the work
58+ and run combinfunc() in parallel, specify num_processes
59+ > 1 and this number of workers will be spun up, the work
60+ split up amongst them evenly. Default: 1
61+ """
62+ use_multiprocessing = num_processes > 1
63+ if use_multiprocessing :
64+ self .task_queue = Queue ()
65+ self .done_queue = Queue ()
66+
6367 self .matrix = []
6468 logger .info ("Generating matrix for %s items - O(n^2)" , len (self .data ))
65- if self . use_multiprocessing :
66- logger .info ("Using multiprocessing on %s processes!" , self . num_processes )
69+ if use_multiprocessing :
70+ logger .info ("Using multiprocessing on %s processes!" , num_processes )
6771
68- if self . use_multiprocessing :
69- logger .info ("Spinning up %s workers" , self . num_processes )
70- processes = [Process (target = self .worker ) for i in range (self . num_processes )]
72+ if use_multiprocessing :
73+ logger .info ("Spinning up %s workers" , num_processes )
74+ processes = [Process (target = self .worker ) for i in range (num_processes )]
7175 [process .start () for process in processes ]
7276
7377 for row_index , item in enumerate (self .data ):
@@ -76,7 +80,7 @@ def genmatrix(self):
7680 len (self .data ),
7781 100.0 * row_index / len (self .data ))
7882 row = {}
79- if self . use_multiprocessing :
83+ if use_multiprocessing :
8084 num_tasks_queued = num_tasks_completed = 0
8185 for col_index , item2 in enumerate (self .data ):
8286 if self .diagonal is not None and col_index == row_index :
@@ -88,14 +92,14 @@ def genmatrix(self):
8892 pass
8993 # Otherwise, this cell is not on the diagonal and we do indeed
9094 # need to call combinfunc()
91- elif self . use_multiprocessing :
95+ elif use_multiprocessing :
9296 # Add that thing to the task queue!
9397 self .task_queue .put ((col_index , item , item2 ))
9498 num_tasks_queued += 1
9599 # Start grabbing the results as we go, so as not to stuff all of
96100 # the worker args into memory at once (as Queue.get() is a
97101 # blocking operation)
98- if num_tasks_queued > self . num_processes :
102+ if num_tasks_queued > num_processes :
99103 col_index , result = self .done_queue .get ()
100104 self .done_queue .task_done ()
101105 row [col_index ] = result
@@ -112,7 +116,7 @@ def genmatrix(self):
112116 # post-process symmetric "lower left triangle"
113117 row [col_index ] = self .matrix [col_index ][row_index ]
114118
115- if self . use_multiprocessing :
119+ if use_multiprocessing :
116120 # Grab the remaining worker task results
117121 while num_tasks_completed < num_tasks_queued :
118122 col_index , result = self .done_queue .get ()
@@ -123,9 +127,9 @@ def genmatrix(self):
123127 row_indexed = [row [index ] for index in range (len (self .data ))]
124128 self .matrix .append (row_indexed )
125129
126- if self . use_multiprocessing :
127- logger .info ("Stopping/joining %s workers" , self . num_processes )
128- [self .task_queue .put ('STOP' ) for i in range (self . num_processes )]
130+ if use_multiprocessing :
131+ logger .info ("Stopping/joining %s workers" , num_processes )
132+ [self .task_queue .put ('STOP' ) for i in range (num_processes )]
129133 [process .join () for process in processes ]
130134
131135 logger .info ("Matrix generated" )
0 commit comments