|
| 1 | + |
| 2 | +import logging |
| 3 | +from multiprocessing import Process, Queue, current_process |
| 4 | + |
| 5 | + |
| 6 | +logger = logging.getLogger(__name__) |
| 7 | + |
| 8 | + |
class Matrix(object):
    """Object representation of the item-item matrix.

    The matrix itself is only computed when genmatrix() is called; until
    then ``self.matrix`` is an empty list.
    """

    def __init__(self, data, combinfunc, symmetric=False, diagonal=None):
        """Takes a list of data and generates a 2D-matrix using the supplied
        combination function to calculate the values.

        PARAMETERS
            data       - the list of items
            combinfunc - the function that is used to calculate the value in
                         a cell. It has to cope with two arguments.
            symmetric  - Whether it will be a symmetric matrix along the
                         diagonal. For example, if the list contains
                         integers, and the combination function is
                         abs(x-y), then the matrix will be symmetric.
                         Default: False
            diagonal   - The value to be put into the diagonal. For some
                         functions, the diagonal will stay constant. An
                         example could be the function "x-y". Then each
                         diagonal cell will be "0". If this value is set to
                         None, then the diagonal will be calculated.
                         Default: None
        """
        self.data = data
        self.combinfunc = combinfunc
        self.symmetric = symmetric
        self.diagonal = diagonal
        # Filled in by genmatrix(); defined here so that __str__ is safe to
        # call on a matrix that has not been generated yet.
        self.matrix = []

    def worker(self):
        """Multiprocessing task function run by worker processes.

        Reads ``(col_index, item, item2)`` tasks from ``self.task_queue``
        until the ``'STOP'`` sentinel is received, and puts
        ``(col_index, combinfunc(item, item2))`` on ``self.done_queue`` for
        each of them.
        """
        tasks_completed = 0
        # NOTE: multiprocessing.Queue has no task_done() (only JoinableQueue
        # does), so the queues are used with plain get()/put() only.
        for col_index, item, item2 in iter(self.task_queue.get, 'STOP'):
            self.done_queue.put((col_index, self.combinfunc(item, item2)))
            tasks_completed += 1
        logger.info("Worker %s performed %s tasks",
                    current_process().name,
                    tasks_completed)

    def genmatrix(self, num_processes=1):
        """Actually generate the matrix and store it in ``self.matrix``.

        PARAMETERS
            num_processes
                - If you want to use multiprocessing to split up the work
                  and run combinfunc() in parallel, specify num_processes
                  > 1 and this number of workers will be spun up, the work
                  split up amongst them evenly. Default: 1
        """
        use_multiprocessing = num_processes > 1
        num_items = len(self.data)

        self.matrix = []
        logger.info("Generating matrix for %s items - O(n^2)", num_items)

        started = []
        try:
            if use_multiprocessing:
                logger.info("Using multiprocessing on %s processes!",
                            num_processes)
                self.task_queue = Queue()
                self.done_queue = Queue()
                logger.info("Spinning up %s workers", num_processes)
                for _ in range(num_processes):
                    process = Process(target=self.worker)
                    process.start()
                    started.append(process)

            for row_index, item in enumerate(self.data):
                logger.debug("Generating row %s/%s (%0.2f%%)",
                             row_index,
                             num_items,
                             100.0 * row_index / num_items)
                self.matrix.append(
                    self._generate_row(row_index, item, num_processes))
        finally:
            # Always release the workers, even when row generation raised;
            # otherwise they would stay blocked on task_queue.get().
            if started:
                logger.info("Stopping/joining %s workers", num_processes)
                for _ in started:
                    self.task_queue.put('STOP')
                for process in started:
                    process.join()

        logger.info("Matrix generated")

    def _generate_row(self, row_index, item, num_processes):
        """Compute and return one matrix row as a list.

        Rows must be generated in order: for a symmetric matrix the cells
        left of the diagonal are copied from the rows already stored in
        ``self.matrix``.
        """
        use_multiprocessing = num_processes > 1
        row = [None] * len(self.data)
        num_tasks_queued = num_tasks_completed = 0

        for col_index, item2 in enumerate(self.data):
            if self.diagonal is not None and col_index == row_index:
                # This is a cell on the diagonal
                row[col_index] = self.diagonal
            elif self.symmetric and col_index < row_index:
                # "Lower left triangle" of a symmetric matrix: the mirrored
                # cell belongs to an earlier, already completed row.
                row[col_index] = self.matrix[col_index][row_index]
            elif use_multiprocessing:
                self.task_queue.put((col_index, item, item2))
                num_tasks_queued += 1
                # Start grabbing the results as we go, so as not to stuff
                # all of the worker args into memory at once (Queue.get()
                # is a blocking operation).
                if num_tasks_queued > num_processes:
                    done_index, result = self.done_queue.get()
                    row[done_index] = result
                    num_tasks_completed += 1
            else:
                # No workers available: compute the cell in line
                row[col_index] = self.combinfunc(item, item2)

        # Grab the remaining worker task results (no-op without workers)
        while num_tasks_completed < num_tasks_queued:
            done_index, result = self.done_queue.get()
            row[done_index] = result
            num_tasks_completed += 1

        return row

    def __str__(self):
        """
        Returns the generated matrix as a cleanly aligned table, one matrix
        row per line. This is useful for debugging.

        Every cell is right-aligned to the width of the widest cell. An
        empty string is returned when the matrix is empty (for example
        because genmatrix() has not been called yet).
        """
        if not self.matrix:
            return ""
        # determine maximum length
        maxlen = max(len(str(cell)) for row in self.matrix for cell in row)
        cell_format = " %%%is |" % maxlen
        row_format = "|" + cell_format * len(self.matrix[0])
        return "\n".join(row_format % tuple(row) for row in self.matrix)
0 commit comments