# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import re
import sys
from pathlib import Path

def run_in_ci():
    """Report whether the script is running inside a CI environment.

    Returns:
        bool: True when the ``CI`` environment variable equals "true"
        (case-insensitive), or when ``TF_BUILD`` or ``JENKINS_URL`` is set
        to a non-empty value; False otherwise.
    """
    if os.environ.get("CI", "").lower() == "true":
        return True

    # For these variables any non-empty value counts, whatever its content.
    return any(os.environ.get(name) for name in ("TF_BUILD", "JENKINS_URL"))


# Outside of CI the check is skipped unless it is requested with --manual.
if not run_in_ci() and '--manual' not in sys.argv:
    # execute check only in CI when the code is productized
    # sys.exit is used because the bare exit() helper is only injected by the
    # site module and is not guaranteed to exist.
    sys.exit(0)

if len(sys.argv) < 3:
    error_message = "Run in the following format: check_supported_ops.py path/to/ops supported_ops.md [--manual]\n"
    error_message += " --manual - script originated to run in CI, use this flag to run it manually"
    raise Exception(error_message)

# Directory that is searched for operation sources
lookup_path = Path(sys.argv[1])
# Markdown document with the table of supported operations (read, then rewritten)
supported_ops_doc = sys.argv[2]

# is_dir() is already False for a path that does not exist
if not lookup_path.is_dir():
    raise Exception(f"Argument '{lookup_path}' isn't a valid path to sources")

# Looking for source files: every *.cpp first, then every *.hpp, searched recursively
files = [*lookup_path.rglob('*.cpp'), *lookup_path.rglob('*.hpp')]

# Matches a single ONNX_OP / ONNX_OP_M registration. Capture groups:
#   1 - optional "_M" macro suffix
#   2 - operation name
#   3 - opset specification without its closing ")" or "}"
#   4 - implementation identifier
#   5 - optional ", <DOMAIN>" tail (None when no domain is given)
op_regex = re.compile(r'ONNX_OP(_M)?\("([a-z0-9_]+)",\s+([^\)\}]+)[\)\}]?,\s+([a-z0-9_:]+)(,\s+[^\)]+)?\);', re.IGNORECASE)

# Documented operations:
#   ops[domain][opname] = {'supported': [...], 'defined': [...], 'limitations': str}
ops = {}

# Domain identifier as written in a registration -> domain name used in the
# documentation; the empty key covers registrations without a domain argument.
known_domains = {
    "": "",
    "OPENVINO_ONNX_DOMAIN": "org.openvinotoolkit",
    "MICROSOFT_DOMAIN": "com.microsoft",
    "PYTORCH_ATEN_DOMAIN": "org.pytorch.aten",
    "MMDEPLOY_DOMAIN": "mmdeploy",
}

# Read the documentation table. Everything up to and including the first two
# table lines (lines containing exactly six "|") is kept verbatim in `hdr`;
# the lines after them are parsed as operation rows into `ops`.
hdr = ""
with open(supported_ops_doc, 'rt') as src:
    table_line = 0
    for line in src:
        if table_line < 2:
            hdr += line
        if line.count('|') == 6:
            table_line += 1
        if table_line <= 2:
            continue
        row = [cell.strip() for cell in line.split('|')]  # Split line by "|" delimiter and remove spaces
        if len(row) < 6:
            # Too few cells to be an operation row (for example a blank line
            # after the table); indexing it below would raise IndexError.
            continue
        domain = row[1]
        opname = row[2]
        # Keep only the integer entries of the "defined" column
        defined = []
        for item in row[4].split(', '):
            try:
                defined.append(int(item))
            except ValueError:
                continue
        # The first row seen for an operation wins; 'supported' is filled in
        # later from the registrations found in the sources.
        ops.setdefault(domain, {}).setdefault(
            opname, {'supported': [], 'defined': defined, 'limitations': row[5]})

# Human-readable problems found while matching registrations against the documentation
documentation_errors = []

# Scan every source file for ONNX_OP registrations and record, for each
# documented operation, the opset versions found in the sources.
for file_path in files:
    with open(file_path.as_posix(), "r") as src:
        # Text of the registration being accumulated; None while outside of one
        reg_macro = None
        for line in src:
            # Multiline registration: (re)start accumulating at the macro name
            if 'ONNX_OP' in line:
                reg_macro = ""
            if reg_macro is None:
                continue
            reg_macro += line
            if ');' not in line:
                continue
            # Registration macro has been found, trying to parse it. The
            # accumulator is reset first so that every exit path below leaves
            # the scanner ready for the next registration.
            registration, reg_macro = reg_macro, None
            m = op_regex.search(registration)
            if m is None:
                documentation_errors.append(f"Registration in file {file_path.as_posix()} is corrupted {registration}, please check correctness")
                continue
            # Group 5 holds the comma-separated domain tail; drop the comma and separator
            domain = m.group(5)[2:].strip() if m.group(5) is not None else ""
            if domain not in known_domains:
                documentation_errors.append(f"Unknown domain found in file {file_path.as_posix()} with identifier {domain}, please modify check_supported_ops.py if needed")
                continue
            domain = known_domains[domain]
            opname = m.group(2)
            opset = m.group(3)
            if domain not in ops:
                documentation_errors.append(f"Domain {domain} is missing in a list of documented operations supported_ops.md, update it by adding operation description")
                continue
            if opname not in ops[domain]:
                documentation_errors.append(f"Operation {domain if domain=='' else domain + '.'}{opname} is missing in a list of documented operations supported_ops.md, update it by adding operation description")
                continue
            # The fixed offsets skip the macro name and its opening parenthesis;
            # for ranges and brace lists only the first version is recorded.
            supported = ops[domain][opname]['supported']
            if opset.startswith('OPSET_SINCE'):
                supported.append(int(opset[12:]))
            elif opset.startswith('OPSET_IN'):
                supported.append(int(opset[9:]))
            elif opset.startswith('OPSET_RANGE'):
                supported.append(int(opset[12:].split(',')[0]))
            elif opset.startswith('{'):
                supported.append(int(opset[1:].split(',')[0]))
            else:
                documentation_errors.append(f"Unsupported opset definition {opset} of operation {opname} in file {file_path.as_posix()}, please modify check_supported_ops.py if needed")

# Print every collected problem, then fail before the document is rewritten.
if documentation_errors:
    for errstr in documentation_errors:
        print('[ONNX Frontend] ' + errstr)
    raise Exception('[ONNX Frontend] failed: due to documentation errors')

# Rewrite the document: the preserved header followed by one row per operation,
# sorted by operation name within each domain.
with open(supported_ops_doc, 'wt') as dst:
    dst.write(hdr)
    # The loop variable must not be called `ops`: that would rebind the global
    # dictionary while it is being iterated.
    for domain, domain_ops in ops.items():
        for op_name in sorted(domain_ops):
            data = domain_ops[op_name]
            # Floor applied to the supported versions: the last documented
            # "defined" entry, or 1 when none is documented or when that
            # version is itself registered in the sources.
            min_opset = data['defined'][-1] if data['defined'] else 1
            if min_opset in data['supported']:
                min_opset = 1
            supported = ', '.join(str(max(i, min_opset)) for i in sorted(data['supported'], reverse=True))
            defined = ', '.join(str(i) for i in data['defined'])
            dst.write(f"|{domain:<24}|{op_name:<56}|{supported:<24}|{defined:<32}|{data['limitations']:<32}|\n")

print("Data collected and stored")