|
| 1 | +import functools as ft |
| 2 | +import itertools as it |
| 3 | +import operator as op |
| 4 | +from math import gcd |
| 5 | +from typing import TYPE_CHECKING, Callable, Iterable, TypeVar |
| 6 | + |
| 7 | +import aoc |
| 8 | +from matrix import Fraction, Matrix, Num |
| 9 | + |
| 10 | +if TYPE_CHECKING: |
| 11 | + from _typeshed import SupportsRichComparisonT |
| 12 | + |
| 13 | + |
| 14 | +type Lights = list[bool] |
| 15 | +type Button = set[int] |
| 16 | +type Joltage = tuple[int, ...] |
| 17 | +type Machine = tuple[Lights, list[Button], Joltage] |
| 18 | +T = TypeVar('T') |
| 19 | + |
| 20 | + |
| 21 | +def main(): |
| 22 | + print(sum_toggles(parsed())) |
| 23 | + print(sum_jolts(parsed())) |
| 24 | + |
| 25 | + |
| 26 | +def sum_toggles(machines: list[Machine]): |
| 27 | + return sum(num_presses(lights, btns) or 0 for lights, btns, _ in machines) |
| 28 | + |
| 29 | + |
| 30 | +def sum_jolts(machines: list[Machine]): |
| 31 | + return sum(joltage_cost(btns, joltage) for _, btns, joltage in machines) |
| 32 | + |
| 33 | + |
| 34 | +def parsed(inp: str | None = None): |
| 35 | + lines = inp.splitlines() if inp else aoc.input_lines() |
| 36 | + return [parse(line) for line in lines] |
| 37 | + |
| 38 | + |
| 39 | +def parse(line: str) -> Machine: |
| 40 | + light_str, *btns_str, jolt_str = line.split(' ') |
| 41 | + return ( |
| 42 | + [s == '#' for s in light_str[1:-1]], |
| 43 | + [set(map(int, s[1:-1].split(','))) for s in btns_str], |
| 44 | + tuple(map(int, jolt_str[1:-1].split(','))), |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def num_presses(lights: Lights, buttons: list[Button]): |
| 49 | + def press(btns: tuple[Button, ...]) -> Joltage: |
| 50 | + return tuple(sum(i in b for b in btns) for i in range(len(lights))) |
| 51 | + |
| 52 | + def pattern(jolts: Joltage) -> Lights: |
| 53 | + return [bool(n % 2) for n in jolts] |
| 54 | + |
| 55 | + for n in range(len(buttons) + 1): |
| 56 | + for combo in it.combinations(buttons, n): |
| 57 | + if pattern(press(combo)) == lights: |
| 58 | + return n |
| 59 | + |
| 60 | + |
| 61 | +def joltage_cost(buttons: list[Button], joltage: Joltage): |
| 62 | + def groupby(itr: Iterable[T], key: Callable[[T], 'SupportsRichComparisonT']): |
| 63 | + return {k: list(v) for k, v in it.groupby(sorted(itr, key=key), key=key)} |
| 64 | + |
| 65 | + def sub_halve(j_a: Joltage, j_b: Joltage) -> Joltage: |
| 66 | + return tuple((a - b) // 2 for a, b, in zip(j_a, j_b)) |
| 67 | + |
| 68 | + def press(btns: tuple[Button, ...]) -> Joltage: |
| 69 | + return tuple(sum(i in b for b in btns) for i in range(len(joltage))) |
| 70 | + |
| 71 | + def pattern(jolts: Joltage) -> Joltage: |
| 72 | + return tuple(n % 2 for n in jolts) |
| 73 | + |
| 74 | + all_btn_combos = (combo for n in range(len(buttons) + 1) for combo in it.combinations(buttons, n)) |
| 75 | + press_patterns = groupby(all_btn_combos, lambda btns: pattern(press(btns))) |
| 76 | + |
| 77 | + @ft.cache |
| 78 | + def cost(jolts: Joltage) -> int: |
| 79 | + if not any(jolts): |
| 80 | + return 0 |
| 81 | + elif any(j < 0 for j in jolts) or pattern(jolts) not in press_patterns: |
| 82 | + return sum(joltage) |
| 83 | + else: |
| 84 | + btn_combos = press_patterns[pattern(jolts)] |
| 85 | + return min(len(btns) + 2 * cost(sub_halve(jolts, press(btns))) for btns in btn_combos) |
| 86 | + |
| 87 | + return cost(joltage) |
| 88 | + |
| 89 | + |
| 90 | +def as_introw(row: list[Num]) -> list[int]: |
| 91 | + denoms = [Fraction(v).denominator for v in row] |
| 92 | + mul = ft.reduce(op.mul, denoms) |
| 93 | + res = [int(v * mul) for v in row] |
| 94 | + div = gcd(*res) |
| 95 | + return [v // div for v in res] |
| 96 | + |
| 97 | + |
| 98 | +def joltage_mtx(buttons: list[Button], joltage: Joltage): |
| 99 | + m = Matrix([ |
| 100 | + [int(i in btn) for btn in buttons] |
| 101 | + for i in range(len(joltage)) |
| 102 | + ]) |
| 103 | + jolt = Matrix([[j] for j in joltage]) |
| 104 | + bounds = [min(int(jolt[i][0]) for i in btn) for btn in buttons] |
| 105 | + |
| 106 | + sys_eq = m.chain(jolt) |
| 107 | + # reduced = sys_eq.reduce() |
| 108 | + # reduced = Matrix([as_introw(row) for row in reduced]) # v4+ |
| 109 | + reduced = sys_eq.reduce_int() |
| 110 | + |
| 111 | + # move free variables to the end (reorder buttons) |
| 112 | + for i in range(reduced.h): |
| 113 | + if reduced[i, i] != 1: |
| 114 | + n = next(j for j in range(i, m.w) if reduced[i, j]) |
| 115 | + reduced = reduced.swap_col(i, n) |
| 116 | + buttons[i], buttons[n] = buttons[n], buttons[i] |
| 117 | + bounds[i], bounds[n] = bounds[n], bounds[i] |
| 118 | + m = m.swap_col(i, n) |
| 119 | + |
| 120 | + free_vars = bounds[reduced.h:m.w] |
| 121 | + free_cols = reduced.subm((0, reduced.h), (reduced.h, m.w)) |
| 122 | + |
| 123 | + def btn_presses(cand_fvars: tuple[int, ...]): |
| 124 | + for i in range(free_cols.h): |
| 125 | + matmul = sum(s * o for s, o in zip(free_cols._mat[i], cand_fvars)) |
| 126 | + diff = reduced._mat[i][-1] - matmul |
| 127 | + presses, rem = divmod(diff, int(reduced._mat[i][i])) |
| 128 | + if diff < 0 or rem: |
| 129 | + yield -2**63 |
| 130 | + return |
| 131 | + yield int(presses) |
| 132 | + |
| 133 | + for presses in cand_fvars: |
| 134 | + yield presses |
| 135 | + |
| 136 | + def solutions(): |
| 137 | + if free_vars: |
| 138 | + free_var_candidates = it.product(*(range(v + 1) for v in free_vars)) |
| 139 | + # Hot loop |
| 140 | + for cand_fvars in free_var_candidates: |
| 141 | + |
| 142 | + # v5 |
| 143 | + num_presses = sum(btn_presses(cand_fvars)) |
| 144 | + if num_presses >= 0: |
| 145 | + yield num_presses |
| 146 | + |
| 147 | + # # v4 |
| 148 | + # num_presses = sum(cand_fvars) |
| 149 | + # for i in range(free_cols.h): |
| 150 | + # matmul = sum(s * o for s, o in zip(free_cols._mat[i], cand_fvars)) |
| 151 | + # diff = reduced._mat[i][-1] - matmul |
| 152 | + # presses, rem = divmod(diff, int(reduced._mat[i][i])) |
| 153 | + # if diff < 0 or rem: |
| 154 | + # num_presses = -1 |
| 155 | + # break |
| 156 | + # num_presses += presses |
| 157 | + # if num_presses >= 0: |
| 158 | + # yield num_presses |
| 159 | + |
| 160 | + # # v3 |
| 161 | + # num_presses = sum(cand_fvars) |
| 162 | + # for i in range(free_cols.h): |
| 163 | + # matmul = sum(s * o for s, o in zip(free_cols._mat[i], cand_fvars)) |
| 164 | + # presses = reduced._mat[i][-1] - matmul |
| 165 | + # if presses < 0 or presses % 1: |
| 166 | + # num_presses = -1 |
| 167 | + # break |
| 168 | + # num_presses += presses |
| 169 | + # if num_presses >= 0: |
| 170 | + # yield num_presses |
| 171 | + |
| 172 | + # # v2 |
| 173 | + # presses = [ |
| 174 | + # reduced._mat[i][-1] - sum(s * o for s, o in zip(free_cols._mat[i], cand_fvars)) |
| 175 | + # for i in range(free_cols.h) |
| 176 | + # ] |
| 177 | + # if not any(presses < 0 or presses % 1 for presses in presses): |
| 178 | + # yield sum(presses + list(cand_fvars)) |
| 179 | + |
| 180 | + # # v1 |
| 181 | + # free_var_contrib = free_cols @ Matrix[cand_fvars].trans() |
| 182 | + # presses = Matrix[*(reduced.tail(1) - free_var_contrib)] |
| 183 | + # if any(presses < 0 or presses % 1 for presses in presses.col(0)): |
| 184 | + # continue |
| 185 | + # yield sum(presses.reduce_nums().col(0) + list(cand_fvars)) |
| 186 | + |
| 187 | + else: |
| 188 | + num_presses, rem = divmod(sum(reduced.col(-1)), 1) |
| 189 | + assert not rem |
| 190 | + yield int(num_presses) |
| 191 | + |
| 192 | + return min(solutions()) |
| 193 | + |
| 194 | + |
| 195 | +if __name__ == "__main__": |
| 196 | + main() |
| 197 | + |
| 198 | + |
| 199 | +import pytest |
| 200 | + |
| 201 | +example = aoc.heredoc(""" |
| 202 | + [.##.] (3) (1,3) (2) (2,3) (0,2) (0,1) {3,5,4,7} |
| 203 | + [...#.] (0,2,3,4) (2,3) (0,4) (0,1,2) (1,2,3,4) {7,5,12,7,2} |
| 204 | + [.###.#] (0,1,2,3,4) (0,3,4) (0,1,2,4,5) (1,2) {10,11,11,5,10,5} |
| 205 | +""") |
| 206 | + |
| 207 | + |
| 208 | +@pytest.mark.parametrize(['machine', 'num'], [ |
| 209 | + (parsed(example)[0], 2), |
| 210 | + (parsed(example)[1], 3), |
| 211 | + (parsed(example)[2], 2), |
| 212 | +]) |
| 213 | +def test_light_cfg_btns(machine, num): |
| 214 | + assert num_presses(machine[0], machine[1]) == num |
| 215 | + |
| 216 | + |
| 217 | +def test_as_introw(): |
| 218 | + row = [1, 0, 0, -1, Fraction(-1, 2), -10] |
| 219 | + assert as_introw(row) == [2, 0, 0, -2, -1, -20] |
| 220 | + |
| 221 | + row = [0, Fraction(1, 3), Fraction(-1, 2), -10] |
| 222 | + assert as_introw(row) == [0, 2, -3, -60] |
| 223 | + |
| 224 | + |
| 225 | +@pytest.mark.parametrize(['machine', 'num'], [ |
| 226 | + (parsed(example)[0], 10), |
| 227 | + (parsed(example)[1], 12), |
| 228 | + (parsed(example)[2], 11), |
| 229 | + (parse('[#.##] (1,3) (0,2,3) {0,13,0,13}'), 13), |
| 230 | + (parse('[##..] (0,2) (0,1) (0,3) {22,14,4,4}'), 22), |
| 231 | + (parse('[###.] (1,2,3) (0,1,2) {15,145,145,130}'), 145), |
| 232 | + (parse('[##..] (0,3) (2) (0,2) (1,2) {20,106,123,13}'), 136), |
| 233 | + (parse('[##..] (2,3) (1,2) (1,2,3) (0) {7,147,167,27}'), 174), |
| 234 | + (parse('[..#.###.#.] (0,3,4,5) (1,4,6) (2,9) (0,4) (2,4,7,8) (0,2,3,4,5,6,7,8,9) (1,6) (1,2,5,6,7) (0,4,7,8) (0,1,2,3,5,8,9) (0,3,4,5,6,7,8,9) (4,6,9) {56,51,67,27,82,44,70,56,49,58}'), 132), |
| 235 | + (parse('[.#..#...##] (2,8) (1,2,5,7,8) (0,3,4,6,8) (1,3,4,5,6,7,8,9) (0,5,6,8,9) (1,2,3,4,7,8) (2,7,9) (1,4,5,7) (2,5,6,9) (0,2,3,6) (1,3,6) (0,1,3,4,6,8) {53,39,56,61,47,41,75,47,59,45}'), 113), |
| 236 | + (parse('[.##..] (0,1) (0,1,2,3) (4) (0,1,4) (0,1,2) (2,4) (0,3) {65,57,43,22,34}'), 75), |
| 237 | + (parse('[#.###] (2,3,4) (0,3) (1,4) (2,4) (0,1,3) (0,2,4) (0,4) {57,17,20,43,39}'), 70), |
| 238 | + (parse('[..#...] (1,2,4,5) (0,2) (0,3,4) (0,1,2,3,5) (2,4,5) (1,3) (0,1) {57,55,62,49,48,49}'), 90), |
| 239 | +]) |
| 240 | +def test_joltage_quick(machine, num): |
| 241 | + assert joltage_mtx(machine[1], machine[2]) == joltage_cost(machine[1], machine[2]) == num |
| 242 | + |
| 243 | + |
| 244 | +@pytest.mark.parametrize(['machine', 'num'], [ |
| 245 | + (parse('[.####....#] (0,1,2,4,6,7,9) (0,1,3,4,5,7,8,9) (3,4,6,8) (2,5,6,8) (0,2,3,5,7,8,9) (0,1,5,7,9) (1,2,6) (1,2,3,4,5,9) (0,1,2,4,5,6,8,9) (1,2,4,5,6,9) (1,3) (0,1,2,6,7,8,9) (0,1,2,3,4) {82,113,122,47,76,67,90,53,59,91}'), 135), |
| 246 | + (parse('[..#....###] (1,4,5,7,8,9) (2,3,4,5,7) (0,4) (0,1,6,8) (3,5) (0,2,3,4,5,6,7,8) (0,3,5,6,7,9) (0,1,2,3,5,6,7,8) (1,4,5,6) (0,1,2,3,5,6,8,9) (0,3,5,9) (0,2,3,4,5,7,9) (1,8,9) {78,30,51,90,71,104,54,67,29,51}'), 126), |
| 247 | + (parse('[..#.#.#..#] (1,3,4,5,6,7,8,9) (0,2,3,4,7) (0,1,2,3,7,8,9) (0,1,2,3,4,9) (1,2,3,4,6,7,8) (0,1,2,3,5,6,7,8,9) (2,3,7) (1,2,5) (0,1,2,3,4,5,6,7,8) (1,4,5,8,9) (0,1,2,3,5,6,8,9) (0,1,3,7,8) (0,2,4,6,8,9) {87,134,107,105,80,83,69,86,116,84}'), 136), |
| 248 | + (parse('[.#.#.###.#] (0,5) (0,1,3,4,6,7,8) (0,1,2,5,6,7,8) (0,1,2,6,7,9) (3,4,6,8) (2,3,4,5,6) (0,1,2,3,4,6,8,9) (3,5,7) (1,3,6) (0,1,5,7,8,9) (2,3,5,6,7,8,9) (1,4) (4,8,9) {61,79,46,93,91,61,86,56,71,37}'), 129), |
| 249 | +]) |
| 250 | +def test_joltage_slow(machine, num): |
| 251 | + assert joltage_mtx(machine[1], machine[2]) == joltage_cost(machine[1], machine[2]) == num |
0 commit comments