forked from FZU-AV-CR/xsdflatten
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathxsdflatten.py
More file actions
executable file
·233 lines (195 loc) · 8.31 KB
/
xsdflatten.py
File metadata and controls
executable file
·233 lines (195 loc) · 8.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#!/usr/bin/env python3
import sys
import re
import copy
import os
import argparse
from lxml import etree
# Optional helper, kept for possible later use:
# def is_element(node):
#     """Check if node is an XML element (skip comments, PIs, text)."""
#     return isinstance(node, etree._Element)
def get_includes_from_file(filename):
    """Return the schemaLocation of every top-level xs:include in *filename*.

    The file is parsed as XML instead of being scanned line by line.
    Includes are therefore found regardless of:

    - the prefix bound to the XML Schema namespace (``xs:``, ``xsd:`` or
      the default namespace);
    - attribute order or quoting style;
    - line breaks inside the tag.

    Includes inside XML comments are ignored. Locations are returned
    exactly as written, in document order. An xs:include without a
    schemaLocation is skipped.

    Raises:
        FileNotFoundError: if *filename* does not exist.
        lxml.etree.XMLSyntaxError: if the file is not well-formed XML.
    """
    xsd_include = '{http://www.w3.org/2001/XMLSchema}include'
    # Open in binary mode so the XML parser honours the document's own
    # encoding declaration; opening the file ourselves also keeps the
    # FileNotFoundError that main() reports as "File not found".
    with open(filename, 'rb') as f:
        root = etree.parse(f).getroot()
    return [
        inc.get('schemaLocation')
        for inc in root.findall(xsd_include)
        if inc.get('schemaLocation')
    ]
def get_includes_recurse(filename, include_set):
    """Recursively collect the absolute paths of all XSD files included by *filename*.

    Each include location is resolved against the directory of the file
    that contains it. The path is then normalised and added to
    *include_set*, which is mutated in place.

    A path that is already in *include_set* is not visited again. This
    means diamond-shaped include graphs are walked once and circular
    includes terminate instead of recursing forever. A consequence is that
    paths present in *include_set* before the call are treated as already
    visited.
    """
    base_dir = os.path.dirname(os.path.abspath(filename))
    for inc in get_includes_from_file(filename):
        # os.path.join returns *inc* unchanged when it is already absolute.
        # normpath collapses "sub/../" segments, so one file maps to one key;
        # without this, "../" cycles would produce ever-longer paths.
        absolute_inc = os.path.normpath(os.path.join(base_dir, inc))
        if absolute_inc in include_set:
            continue
        include_set.add(absolute_inc)
        get_includes_recurse(absolute_inc, include_set)
def get_xml_tree_from_file(filename):
    """Parse *filename* with lxml and return the document's root element."""
    return etree.parse(filename).getroot()
def remove_includes_and_imports(root):
    """Remove the direct xs:include and xs:import children of *root*.

    Children are matched by the XML Schema namespace URI (Clark notation).
    This works whatever prefix the document binds to that namespace:
    ``xs:``, ``xsd:`` or the default namespace. Looking up a literal ``xs``
    prefix in ``root.nsmap`` fails for the latter two.

    Args:
        root: root (xs:schema) element; modified in place.

    Returns:
        The same *root* object, for call chaining.
    """
    xsd_ns = '{http://www.w3.org/2001/XMLSchema}'
    # Materialise both lists before removing, so nothing is mutated while
    # it is being searched.
    doomed = root.findall(xsd_ns + 'include') + root.findall(xsd_ns + 'import')
    for node in doomed:
        root.remove(node)
    return root
def collect_imports(root, import_set):
    """Record every direct xs:import child of *root* in *import_set*.

    Imports are matched by the XML Schema namespace URI, so any prefix bound
    to it works (``xs:``, ``xsd:`` or the default namespace). Each import
    is stored as a ``(namespace, schemaLocation)`` tuple. An import without
    a (non-empty) schemaLocation, such as the XML namespace, is stored as
    ``(namespace, None)``.

    Args:
        root: root (xs:schema) element of a parsed schema document.
        import_set: set updated in place.
    """
    for imp in root.findall('{http://www.w3.org/2001/XMLSchema}import'):
        namespace = imp.get('namespace')
        if not namespace:
            # NOTE(review): imports without a namespace attribute are not
            # recorded (unchanged behaviour); the flattener writes the
            # namespace attribute unconditionally, so None cannot be stored.
            continue
        import_set.add((namespace, imp.get('schemaLocation') or None))
def add_elements(target_root, source_root, processed_types):
    """Append deep copies of *source_root*'s children to *target_root*, skipping duplicates.

    A named top-level schema component is skipped when its key is already
    in *processed_types*; otherwise the key is added and the component is
    copied. Keys follow XSD symbol spaces:

    - complexType / simpleType share one space and use the bare name.
    - element, attribute, attributeGroup, group and notation each have
      their own space and use ``"<kind>:<name>"``
      (e.g. ``"element:Foo"``).

    Unnamed elements (annotations, anonymous components) are always copied.
    Non-element nodes (comments, processing instructions, entities) are
    also always copied verbatim.

    Args:
        target_root: element receiving the copies (modified in place).
        source_root: element whose children are copied (left untouched).
        processed_types: set of keys already present; updated in place.
    """
    for child in source_root:
        tag = getattr(child, 'tag', None)
        if not isinstance(tag, str):
            # Comments, PIs and entities carry their factory function as
            # the tag. Calling it (as the old code did) fails for PIs and
            # entities, which require arguments, so copy them unchanged.
            target_root.append(copy.deepcopy(child))
            continue

        kind = tag.rpartition('}')[2]  # local name without the namespace
        name = child.get('name')
        if name and kind in ('complexType', 'simpleType'):
            key = name
        elif name and kind in ('element', 'attribute', 'attributeGroup',
                               'group', 'notation'):
            key = f'{kind}:{name}'
        else:
            key = None

        if key is not None:
            if key in processed_types:
                continue  # duplicate definition from another file
            processed_types.add(key)
        target_root.append(copy.deepcopy(child))
def _with_merged_namespaces(root, extra_roots, import_set):
    """Return *root*, or a rebuilt copy of it, declaring the prefixes the merged schema needs.

    Included schemas often refer to other namespaces only through prefixes
    inside attribute values (e.g. ``type="gml:PointType"``). Copying their
    components does not carry such declarations along, so every prefixed
    declaration on an included schema's root element is added here.

    Rules:

    - The main schema's own bindings always win.
    - Default (unprefixed) namespaces are never added, because that would
      change how unprefixed QNames in the main schema resolve.
    - For backward compatibility, ``gml`` is bound to GML 3.2 whenever that
      namespace is imported but no file declares the prefix.

    lxml's ``nsmap`` is read-only. When declarations must be added, a new
    root element with the same tag, attributes and leading text is created
    and all children are moved into it.
    """
    gml_ns = 'http://www.opengis.net/gml/3.2'
    nsmap = dict(root.nsmap)  # main bindings inserted first so they stay preferred
    for other in extra_roots:
        for prefix, uri in other.nsmap.items():
            if prefix is not None and prefix not in nsmap:
                nsmap[prefix] = uri
    if 'gml' not in nsmap and any(ns == gml_ns for ns, _ in import_set):
        nsmap['gml'] = gml_ns
    if nsmap == root.nsmap:
        return root
    merged = etree.Element(root.tag, attrib=dict(root.attrib), nsmap=nsmap)
    merged.text = root.text
    for child in list(root):
        merged.append(child)  # moves the child (and its tail) out of root
    return merged


def flatten_file(filename):
    """Flatten an XSD file by merging all (transitively) included schemas into it.

    Steps:

    1. Parse the main schema and remove its xs:include / xs:import elements.
    2. Insert one xs:import per unique ``(namespace, schemaLocation)``
       pair, collected from every file, at the top of the schema.
    3. For each included file, append an "Imported from <path>" comment
       followed by its top-level components.

    Duplicate components (already defined by the main file or by an earlier
    include) are skipped as decided by add_elements. Included files are
    processed in sorted path order, so the output is reproducible across
    runs.

    Args:
        filename: path of the main XSD file.

    Returns:
        The flattened schema, pretty-printed, as a unicode string.

    Errors from the file-reading and parsing helpers propagate unchanged.
    """
    include_set = set()
    import_set = set()
    processed_types = set()  # keys of components already present in the output
    get_includes_recurse(filename, include_set)

    root = get_xml_tree_from_file(filename)

    # Parse each included file exactly once. Sorting matters: set order
    # depends on per-process string hashing, and it determines both the
    # output layout and which of two duplicate definitions is kept. The
    # main file is skipped in case an include cycles back to it.
    main_path = os.path.abspath(filename)
    included = [
        (inc_file, get_xml_tree_from_file(inc_file))
        for inc_file in sorted(include_set)
        if inc_file != main_path
    ]

    # Collect imports from every file before any of them are removed.
    collect_imports(root, import_set)
    for _, inc_root in included:
        collect_imports(inc_root, import_set)

    root = remove_includes_and_imports(root)
    root = _with_merged_namespaces(
        root, [inc_root for _, inc_root in included], import_set)

    # Add all unique imports at the beginning (right after <schema>).
    # A namespace can occur both with and without a schemaLocation; sort
    # None as '' so sorted() never compares None with str (TypeError).
    schema_ns = '{http://www.w3.org/2001/XMLSchema}'
    ordered_imports = sorted(import_set,
                             key=lambda item: (item[0], item[1] or ''))
    for position, (namespace, schema_location) in enumerate(ordered_imports):
        import_elem = etree.Element(f'{schema_ns}import')
        import_elem.set('namespace', namespace)
        if schema_location:
            import_elem.set('schemaLocation', schema_location)
        # Add proper indentation and newline
        import_elem.tail = '\n '
        root.insert(position, import_elem)

    # Register the main file's own components. Running them through
    # add_elements against a throw-away target reuses exactly the
    # duplicate-detection keys applied to the includes below, so the two
    # can never disagree.
    add_elements(etree.Element('scratch'), root, processed_types)

    # Merge in the components of the includes, skipping duplicates.
    for inc_file, inc_root in included:
        inc_root = remove_includes_and_imports(inc_root)
        root.append(etree.Comment('Imported from %s' % inc_file))
        add_elements(root, inc_root, processed_types)

    return etree.tostring(root, pretty_print=True, encoding='unicode')
def main(argv=None):
    """Command-line interface for flattening XSD files.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The
            default ``None`` keeps the normal command-line behaviour.
            Passing a list lets other code (and tests) drive the CLI
            directly.

    Invalid input or a failed flatten is reported through ``parser.error``.
    That prints the usage line plus the message to stderr and exits with
    status 2.
    """
    parser = argparse.ArgumentParser(
        description='Flatten XSD files by merging includes into a single file',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""Examples:
%(prog)s schema.xsd
%(prog)s --output flattened.xsd input_schema.xsd"""
    )
    parser.add_argument('input_file',
                        help='Input XSD file to flatten')
    parser.add_argument('-o', '--output',
                        help='Output file (if not specified, prints to stdout)',
                        metavar='FILE')
    # argparse already exits on its own for bad usage or --help, so no
    # wrapper around parse_args is needed.
    args = parser.parse_args(argv)

    # Validate input file exists
    if not os.path.isfile(args.input_file):
        parser.error(f"Input file '{args.input_file}' does not exist or is not a file")
    # Validate input file is XSD
    if not args.input_file.lower().endswith('.xsd'):
        parser.error(f"Input file '{args.input_file}' does not have .xsd extension")

    try:
        flattened_content = flatten_file(args.input_file)
        if args.output:
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(flattened_content)
        else:
            print(flattened_content)
    except FileNotFoundError as e:
        parser.error(f"File not found: {e}")
    except etree.XMLSyntaxError as e:
        parser.error(f"XML parsing error: {e}")
    except Exception as e:
        # Top-level CLI boundary: report any other failure the same way.
        parser.error(f"Unexpected error: {e}")
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()