Skip to content

Commit 4d60839

Browse files
⚡️ Speed up function _is_separating_line by 18% in PR #110 (trace-deterministic)
To optimize the functions `_is_separating_line_value` and `_is_separating_line`, this change takes three steps:

1. **Inline and simplify**: `_is_separating_line_value` is small and is called in a tight loop, so inlining its check into `_is_separating_line` saves the function-call overhead.
2. **Reduce type checks**: instead of checking the type multiple times, the boolean logic is simplified for clarity and efficiency.
3. **Early exit**: if a check fails, the function returns immediately.

The optimized version of the code is shown in the diff below.

### Explanation of Changes

1. **Inline check**: the logic of `_is_separating_line_value` is inlined within the main function.
2. **Simplified boolean checks**: the type checks and the string-type verifications are combined into single `if` statements.
3. **Early exits**: each condition returns as soon as a match is found, minimizing unnecessary evaluations.

These optimizations target both runtime efficiency and clarity.

Note: the row type check changes from an exact comparison (`type(row) == list or type(row) == str`) to `isinstance`, so subclasses of `list` and `str` are now also accepted as candidate separating lines.
1 parent 9998689 commit 4d60839

File tree

1 file changed

+74
-84
lines changed

1 file changed

+74
-84
lines changed

codeflash/code_utils/tabulate.py

Lines changed: 74 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
"""Pretty-print tabular data."""
44

5+
import dataclasses
6+
import math
7+
import re
58
import warnings
69
from collections import namedtuple
710
from collections.abc import Iterable
8-
from itertools import chain, zip_longest as izip_longest
911
from functools import reduce
10-
import re
11-
import math
12-
import dataclasses
12+
from itertools import chain
13+
from itertools import zip_longest as izip_longest
14+
1315
import wcwidth # optional wide-character (CJK) support
1416

1517
__all__ = ["tabulate", "tabulate_formats"]
@@ -57,31 +59,33 @@ def _is_separating_line_value(value):
5759

5860

5961
def _is_separating_line(row):
60-
row_type = type(row)
61-
is_sl = (row_type == list or row_type == str) and (
62-
(len(row) >= 1 and _is_separating_line_value(row[0])) or (len(row) >= 2 and _is_separating_line_value(row[1]))
63-
)
64-
65-
return is_sl
62+
if not (isinstance(row, list) or isinstance(row, str)):
63+
return False
64+
if len(row) >= 1 and isinstance(row[0], str) and row[0].strip() == SEPARATING_LINE:
65+
return True
66+
if len(row) >= 2 and isinstance(row[1], str) and row[1].strip() == SEPARATING_LINE:
67+
return True
68+
return False
6669

6770

6871
def _pipe_segment_with_colons(align, colwidth):
6972
"""Return a segment of a horizontal line with optional colons which
70-
indicate column's alignment (as in `pipe` output format)."""
73+
indicate column's alignment (as in `pipe` output format).
74+
"""
7175
w = colwidth
7276
if align in ["right", "decimal"]:
7377
return ("-" * (w - 1)) + ":"
74-
elif align == "center":
78+
if align == "center":
7579
return ":" + ("-" * (w - 2)) + ":"
76-
elif align == "left":
80+
if align == "left":
7781
return ":" + ("-" * (w - 1))
78-
else:
79-
return "-" * w
82+
return "-" * w
8083

8184

8285
def _pipe_line_with_colons(colwidths, colaligns):
8386
"""Return a horizontal line with optional colons to indicate column's
84-
alignment (as in `pipe` output format)."""
87+
alignment (as in `pipe` output format).
88+
"""
8589
if not colaligns: # e.g. printing an empty data frame (github issue #15)
8690
colaligns = [""] * len(colwidths)
8791
segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)]
@@ -111,7 +115,7 @@ def _pipe_line_with_colons(colwidths, colaligns):
111115
),
112116
}
113117

114-
tabulate_formats = list(sorted(_table_formats.keys()))
118+
tabulate_formats = sorted(_table_formats.keys())
115119

116120
# The table formats for which multiline cells will be folded into subsequent
117121
# table rows. The key is the original format specified at the API. The value is
@@ -211,35 +215,31 @@ def _type(string, has_invisible=True, numparse=True):
211215

212216
if string is None or (isinstance(string, (bytes, str)) and not string):
213217
return type(None)
214-
elif hasattr(string, "isoformat"): # datetime.datetime, date, and time
218+
if hasattr(string, "isoformat"): # datetime.datetime, date, and time
215219
return str
216-
elif _isbool(string):
220+
if _isbool(string):
217221
return bool
218-
elif numparse and (
222+
if numparse and (
219223
_isint(string) or (isinstance(string, str) and _isnumber_with_thousands_separator(string) and "." not in string)
220224
):
221225
return int
222-
elif numparse and (_isnumber(string) or (isinstance(string, str) and _isnumber_with_thousands_separator(string))):
226+
if numparse and (_isnumber(string) or (isinstance(string, str) and _isnumber_with_thousands_separator(string))):
223227
return float
224-
elif isinstance(string, bytes):
228+
if isinstance(string, bytes):
225229
return bytes
226-
else:
227-
return str
230+
return str
228231

229232

230233
def _afterpoint(string):
231234
if _isnumber(string) or _isnumber_with_thousands_separator(string):
232235
if _isint(string):
233236
return -1
234-
else:
235-
pos = string.rfind(".")
236-
pos = string.lower().rfind("e") if pos < 0 else pos
237-
if pos >= 0:
238-
return len(string) - pos - 1
239-
else:
240-
return -1 # no point
241-
else:
242-
return -1 # not a number
237+
pos = string.rfind(".")
238+
pos = string.lower().rfind("e") if pos < 0 else pos
239+
if pos >= 0:
240+
return len(string) - pos - 1
241+
return -1 # no point
242+
return -1 # not a number
243243

244244

245245
def _padleft(width, s):
@@ -264,8 +264,8 @@ def _padnone(ignore_width, s):
264264
def _strip_ansi(s):
265265
if isinstance(s, str):
266266
return _ansi_codes.sub(r"\4", s)
267-
else: # a bytestring
268-
return _ansi_codes_bytes.sub(r"\4", s)
267+
# a bytestring
268+
return _ansi_codes_bytes.sub(r"\4", s)
269269

270270

271271
def _visible_width(s):
@@ -275,15 +275,14 @@ def _visible_width(s):
275275
len_fn = len
276276
if isinstance(s, (str, bytes)):
277277
return len_fn(_strip_ansi(s))
278-
else:
279-
return len_fn(str(s))
278+
return len_fn(str(s))
280279

281280

282281
def _is_multiline(s):
283282
if isinstance(s, str):
284283
return bool(re.search(_multiline_codes, s))
285-
else: # a bytestring
286-
return bool(re.search(_multiline_codes_bytes, s))
284+
# a bytestring
285+
return bool(re.search(_multiline_codes_bytes, s))
287286

288287

289288
def _multiline_width(multiline_s, line_width_fn=len):
@@ -386,16 +385,15 @@ def _align_column(
386385
"\n".join([padfn(w, s) for s, w in zip((ms.splitlines() or ms), mw)])
387386
for ms, mw in zip(strings, visible_widths)
388387
]
389-
else: # single-line cell values
390-
if not enable_widechars and not has_invisible:
391-
padded_strings = [padfn(maxwidth, s) for s in strings]
392-
else:
393-
# enable wide-character width corrections
394-
s_lens = list(map(len, strings))
395-
visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)]
396-
# wcswidth and _visible_width don't count invisible characters;
397-
# padfn doesn't need to apply another correction
398-
padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)]
388+
elif not enable_widechars and not has_invisible:
389+
padded_strings = [padfn(maxwidth, s) for s in strings]
390+
else:
391+
# enable wide-character width corrections
392+
s_lens = list(map(len, strings))
393+
visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)]
394+
# wcswidth and _visible_width don't count invisible characters;
395+
# padfn doesn't need to apply another correction
396+
padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)]
399397
return padded_strings
400398

401399

@@ -419,7 +417,7 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True):
419417

420418
if valtype is str:
421419
return f"{val}"
422-
elif valtype is int:
420+
if valtype is int:
423421
if isinstance(val, str):
424422
val_striped = val.encode("unicode_escape").decode("utf-8")
425423
colored = re.search(r"(\\[xX]+[0-9a-fA-F]+\[\d+[mM]+)([0-9.]+)(\\.*)$", val_striped)
@@ -432,7 +430,7 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True):
432430
val = val_new.encode("utf-8").decode("unicode_escape")
433431
intfmt = ""
434432
return format(val, intfmt)
435-
elif valtype is bytes:
433+
if valtype is bytes:
436434
try:
437435
return str(val, "ascii")
438436
except (TypeError, UnicodeDecodeError):
@@ -443,16 +441,15 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True):
443441
raw_val = _strip_ansi(val)
444442
formatted_val = format(float(raw_val), floatfmt)
445443
return val.replace(raw_val, formatted_val)
446-
else:
447-
if isinstance(val, str) and "," in val:
448-
val = val.replace(",", "") # handle thousands-separators
449-
return format(float(val), floatfmt)
444+
if isinstance(val, str) and "," in val:
445+
val = val.replace(",", "") # handle thousands-separators
446+
return format(float(val), floatfmt)
450447
else:
451448
return f"{val}"
452449

453450

454451
def _align_header(header, alignment, width, visible_width, is_multiline=False, width_fn=None):
455-
"Pad string header to width chars given known visible_width of the header."
452+
"""Pad string header to width chars given known visible_width of the header."""
456453
if is_multiline:
457454
header_lines = re.split(_multiline_codes, header)
458455
padded_lines = [_align_header(h, alignment, width, width_fn(h)) for h in header_lines]
@@ -462,12 +459,11 @@ def _align_header(header, alignment, width, visible_width, is_multiline=False, w
462459
width += ninvisible
463460
if alignment == "left":
464461
return _padright(width, header)
465-
elif alignment == "center":
462+
if alignment == "center":
466463
return _padboth(width, header)
467-
elif not alignment:
464+
if not alignment:
468465
return f"{header}"
469-
else:
470-
return _padleft(width, header)
466+
return _padleft(width, header)
471467

472468

473469
def _remove_separating_lines(rows):
@@ -480,12 +476,11 @@ def _remove_separating_lines(rows):
480476
else:
481477
sans_rows.append(row)
482478
return sans_rows, separating_lines
483-
else:
484-
return rows, None
479+
return rows, None
485480

486481

487482
def _bool(val):
488-
"A wrapper around standard bool() which doesn't throw on NumPy arrays"
483+
"""A wrapper around standard bool() which doesn't throw on NumPy arrays"""
489484
try:
490485
return bool(val)
491486
except ValueError: # val is likely to be a numpy array with many elements
@@ -506,7 +501,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
506501
index = None
507502
if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"):
508503
# dict-like and pandas.DataFrame?
509-
if hasattr(tabular_data.values, "__call__"):
504+
if callable(tabular_data.values):
510505
# likely a conventional dict
511506
keys = tabular_data.keys()
512507
try:
@@ -541,7 +536,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
541536
if headers == "keys" and not rows:
542537
# an empty table (issue #81)
543538
headers = []
544-
elif headers == "keys" and hasattr(tabular_data, "dtype") and getattr(tabular_data.dtype, "names"):
539+
elif headers == "keys" and hasattr(tabular_data, "dtype") and tabular_data.dtype.names:
545540
# numpy record array
546541
headers = tabular_data.dtype.names
547542
elif headers == "keys" and len(rows) > 0 and isinstance(rows[0], tuple) and hasattr(rows[0], "_fields"):
@@ -752,7 +747,7 @@ def tabulate(
752747
for idx, align in enumerate(colalign):
753748
if not idx < len(aligns):
754749
break
755-
elif align != "global":
750+
if align != "global":
756751
aligns[idx] = align
757752
minwidths = [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols)
758753
aligns_copy = aligns.copy()
@@ -788,7 +783,7 @@ def tabulate(
788783
hidx = headers_pad + idx
789784
if not hidx < len(aligns_headers):
790785
break
791-
elif align == "same" and hidx < len(aligns): # same as column align
786+
if align == "same" and hidx < len(aligns): # same as column align
792787
aligns_headers[hidx] = aligns[hidx]
793788
elif align != "global":
794789
aligns_headers[hidx] = align
@@ -816,15 +811,13 @@ def _expand_numparse(disable_numparse, column_count):
816811
for index in disable_numparse:
817812
numparses[index] = False
818813
return numparses
819-
else:
820-
return [not disable_numparse] * column_count
814+
return [not disable_numparse] * column_count
821815

822816

823817
def _expand_iterable(original, num_desired, default):
824818
if isinstance(original, Iterable) and not isinstance(original, str):
825819
return original + [default] * (num_desired - len(original))
826-
else:
827-
return [default] * num_desired
820+
return [default] * num_desired
828821

829822

830823
def _pad_row(cells, padding):
@@ -834,8 +827,7 @@ def _pad_row(cells, padding):
834827
pad = " " * padding
835828
padded_cells = [pad + cell + pad for cell in cells]
836829
return padded_cells
837-
else:
838-
return cells
830+
return cells
839831

840832

841833
def _build_simple_row(padded_cells, rowfmt):
@@ -846,10 +838,9 @@ def _build_simple_row(padded_cells, rowfmt):
846838
def _build_row(padded_cells, colwidths, colaligns, rowfmt):
847839
if not rowfmt:
848840
return None
849-
if hasattr(rowfmt, "__call__"):
841+
if callable(rowfmt):
850842
return rowfmt(padded_cells, colwidths, colaligns)
851-
else:
852-
return _build_simple_row(padded_cells, rowfmt)
843+
return _build_simple_row(padded_cells, rowfmt)
853844

854845

855846
def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt, rowalign=None):
@@ -859,15 +850,14 @@ def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt, rowalig
859850

860851

861852
def _build_line(colwidths, colaligns, linefmt):
862-
"Return a string which represents a horizontal line."
853+
"""Return a string which represents a horizontal line."""
863854
if not linefmt:
864855
return None
865-
if hasattr(linefmt, "__call__"):
856+
if callable(linefmt):
866857
return linefmt(colwidths, colaligns)
867-
else:
868-
begin, fill, sep, end = linefmt
869-
cells = [fill * w for w in colwidths]
870-
return _build_simple_row(cells, (begin, sep, end))
858+
begin, fill, sep, end = linefmt
859+
cells = [fill * w for w in colwidths]
860+
return _build_simple_row(cells, (begin, sep, end))
871861

872862

873863
def _append_line(lines, colwidths, colaligns, linefmt):
@@ -921,5 +911,5 @@ def _format_table(fmt, headers, headersaligns, rows, colwidths, colaligns, is_mu
921911
if headers or rows:
922912
output = "\n".join(lines)
923913
return output
924-
else: # a completely empty table
925-
return ""
914+
# a completely empty table
915+
return ""

0 commit comments

Comments
 (0)