diff --git a/branca/utilities.py b/branca/utilities.py index 14a1d46..ec280f3 100644 --- a/branca/utilities.py +++ b/branca/utilities.py @@ -66,40 +66,27 @@ def linear_gradient(hexList: List[str], nColors: int) -> List[str]: nColors where the colors are linearly interpolated between the (r, g, b) tuples that are given. """ - - def _scale(start, finish, length, i): - """ - Return the value correct value of a number that is in between start - and finish, for use in a loop of length *length*. - - """ - base = 16 - - fraction = float(i) / (length - 1) - raynge = int(finish, base) - int(start, base) - thex = hex(int(int(start, base) + fraction * raynge)).split("x")[-1] - if len(thex) != 2: - thex = "0" + thex - return thex - - allColors: List[str] = [] - # Separate (R, G, B) pairs. - for start, end in zip(hexList[:-1], hexList[1:]): - # Linearly interpolate between pair of hex ###### values and - # add to list. - nInterpolate = 765 - for index in range(nInterpolate): - r = _scale(start[1:3], end[1:3], nInterpolate, index) - g = _scale(start[3:5], end[3:5], nInterpolate, index) - b = _scale(start[5:7], end[5:7], nInterpolate, index) - allColors.append("".join(["#", r, g, b])) - - # Pick only nColors colors from the total list. + input_color_bytes = [ + [int(_hex[i : i + 2], 16) for i in (1, 3, 5)] for _hex in hexList + ] + # to have the same output as the previous version of this function we use + # a resolution of 765 'indexes' per color bin. + resolution = 765 + n_indexes = resolution * (len(hexList) - 1) result: List[str] = [] - for counter in range(nColors): - fraction = float(counter) / (nColors - 1) - index = int(fraction * (len(allColors) - 1)) - result.append(allColors[index]) + for counter in range(nColors - 1): + fraction_overall = float(counter) / (nColors - 1) + index_overall = int(fraction_overall * (n_indexes - 1)) + index_color_bin = index_overall % resolution + idx_input = index_overall // resolution + fraction = index_color_bin / (resolution - 1) + start = input_color_bytes[idx_input] + end = input_color_bytes[idx_input + 1] + new_color_bytes = [int(x + fraction * (y - x)) for x, y in zip(start, end)] + new_color_hexs = [hex(x)[2:].zfill(2) for x in new_color_bytes] + result.append("#" + "".join(new_color_hexs)) + + result.append(hexList[-1].lower()) return result diff --git a/tests/test_utilities.py b/tests/test_utilities.py index a5b83d0..754fbb1 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,6 +1,7 @@ import json import os from pathlib import Path +from typing import List import pytest @@ -169,3 +170,81 @@ def test_write_png_rgb(): ] png = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x04\x00\x00\x00\x02\x08\x06\x00\x00\x00\x7f\xa8}c\x00\x00\x00-IDATx\xda\x01"\x00\xdd\xff\x00\xff\xa7G\xffp\xff+\xff\x9e\x1cH\xff9\x90$\xff\x00\x93\xe9\xb8\xff\x0cz\xe2\xff\xc6\xca\xff\xff\xd4W\xd0\xffYw\x15\x95\xcf\xb9@D\x00\x00\x00\x00IEND\xaeB`\x82' # noqa E501 assert ut.write_png(image_rgb) == png + + +@pytest.mark.parametrize( + "hex_list, n_colors, expected_output", + [ + (["#000000", "#FFFFFF"], 2, ["#000000", "#ffffff"]), + (["#000000", "#FFFFFF"], 4, ["#000000", "#545454", "#a9a9a9", "#ffffff"]), + (["#FF0000", "#00FF00", "#0000FF"], 3, ["#ff0000", "#00ff00", "#0000ff"]), + ( + ["#FF0000", "#00FF00", "#0000FF"], + 4, + ["#ff0000", "#55a900", "#00aa54", "#0000ff"], + ), + ( + ["#000000", "#0000FF"], + 5, + ["#000000", "#00003f", "#00007f", "#0000bf", "#0000ff"], + ), + ( + ["#FFFFFF", "#000000"], + 5, + ["#ffffff", "#bfbfbf", "#7f7f7f", "#3f3f3f", "#000000"], + ), + ( + ["#FF0000", "#00FF00", "#0000FF"], + 5, + ["#ff0000", "#7f7f00", "#00ff00", "#007f7f", "#0000ff"], + ), + ( + ["#FF0000", "#00FF00", "#0000FF"], + 7, + [ + "#ff0000", + "#aa5400", + "#55a900", + "#00ff00", + "#00aa54", + "#0055a9", + "#0000ff", + ], + ), + ( + ["#abcdef", "#010603", "#f7f9f3"], + 4, + ["#abcdef", "#394851", "#525652", "#f7f9f3"], + ), + ( + ["#abcdef", "#010603", "#f7f9f3"], + 7, + [ + "#abcdef", + "#728aa0", + "#394851", + "#010603", + "#525652", + "#a4a7a2", + "#f7f9f3", + ], + ), + ( + ["#00abff", "#ff00ab", "#abff00", "#00abff"], + 4, + ["#00abff", "#ff00ab", "#abff00", "#00abff"], + ), + ( + ["#00abff", "#ff00ab", "#abff00", "#00abff"], + 6, + ["#00abff", "#9844cc", "#ee3288", "#bbcb22", "#66dd65", "#00abff"], + ), + ], +) +def test_linear_gradient( + hex_list: List[str], + n_colors: int, + expected_output: List[str], +): + result = ut.linear_gradient(hex_list, n_colors) + assert result == expected_output