diff --git a/.gitignore b/.gitignore index 70391cd2..78eb851a 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,12 @@ dmypy.json # Pyre type checker .pyre/ + +# VSCode / pylint +.vscode/ +./pylintrc + +# Local Folder to Store Possible Bugs +possible_bugs/ +patterns.txt + diff --git a/deployment/run.sh b/deployment/run.sh index c87fe06d..c90fb93c 100755 --- a/deployment/run.sh +++ b/deployment/run.sh @@ -47,6 +47,16 @@ run_groovy_from_source() { --language groovy --cast-numbers } +install_ts_nightly() { + npm install -g typescript@next +} + +run_typescript() { + python3.9 hephaestus.py -s $TIME_TO_RUN -t 0 -w $CORES --batch 30 -P \ + --language typescript --disable-use-site-variance \ + --error-filter-patterns patterns.txt +} + run_multiple_versions() { cd $CHECK_TYPE_SYSTEMS git pull @@ -68,13 +78,18 @@ then exit 0 fi -while getopts "hksagS" OPTION; do +while getopts "hkstagS" OPTION; do case $OPTION in k) simple_run ;; + t) + install_ts_nightly + run_typescript + ;; + s) run_from_source ;; @@ -94,12 +109,14 @@ while getopts "hksagS" OPTION; do h) echo "Usage:" echo "init.sh -k " + echo "init.sh -t " echo "init.sh -s " echo "init.sh -a " echo "init.sh -g " echo "init.sh -S " echo "" echo " -k Simple run" + echo " -t Install latest typescript nightly version" echo " -s Run from source" echo " -a Run multiple versions" echo " -g Simple run groovy" diff --git a/deployment/setup.sh b/deployment/setup.sh index 1f1ca3b6..152c9d22 100755 --- a/deployment/setup.sh +++ b/deployment/setup.sh @@ -77,6 +77,15 @@ install_groovy_from_source() { source $HOME/.bash_profile } +install_npm() { + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.2/install.sh | bash + nvm install node +} + +install_typescript() { + npm install -g typescript@next +} + install_kotlin() { install_java sdk install kotlin @@ -170,7 +179,7 @@ then exit 0 fi -while getopts "hskagS" OPTION; do +while getopts "hsktagS" OPTION; do case $OPTION in k) @@ -180,6 +189,11 @@ while getopts "hskagS" OPTION; do add_run_script_to_path ;; + t) + install_npm + install_typescript + ;; + s) install_deps install_kotlin_from_source @@ -211,12 +225,14 @@ while getopts "hskagS" OPTION; do h) echo "Usage:" echo "init.sh -k " + echo "init.sh -t " echo "init.sh -s " echo "init.sh -a " echo "init.sh -g " echo "init.sh -S " echo "" echo " -k Install latest kotlin version" + echo " -t Install latest typescript nightly version" echo " -s Install kotlin from source" echo " -a Install all kotlin versions" echo " -g Install latest groovy version" diff --git a/hephaestus.py b/hephaestus.py index 268704ff..d9d7f0e0 100755 --- a/hephaestus.py +++ b/hephaestus.py @@ -17,9 +17,11 @@ from src.compilers.kotlin import KotlinCompiler from src.compilers.groovy import GroovyCompiler from src.compilers.java import JavaCompiler +from src.compilers.typescript import TypeScriptCompiler from src.translators.kotlin import KotlinTranslator from src.translators.groovy import GroovyTranslator from src.translators.java import JavaTranslator +from src.translators.typescript import TypeScriptTranslator from src.modules.processor import ProgramProcessor @@ -27,12 +29,14 @@ TRANSLATORS = { 'kotlin': KotlinTranslator, 'groovy': GroovyTranslator, - 'java': JavaTranslator + 'java': JavaTranslator, + 'typescript' : TypeScriptTranslator } COMPILERS = { 'kotlin': KotlinCompiler, 'groovy': GroovyCompiler, - 'java': JavaCompiler + 'java': JavaCompiler, + 'typescript': TypeScriptCompiler } STATS = { "Info": { diff --git a/pylintrc b/pylintrc deleted file mode 100644 index 21394a9f..00000000 --- a/pylintrc +++ /dev/null @@ -1,504 +0,0 @@ -[MASTER] - -# Specify a score threshold to be exceeded before program exits with error. -fail-under=9.5 - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=tests,pickle_it.py,setup.py,examples,.eggs - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python module names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=missing-function-docstring, - missing-class-docstring, - missing-module-docstring, - too-many-public-methods - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -#argument-rgx= - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Bad variable names regexes, separated by a comma. If names match any regex, -# they will always be refused -bad-names-rgxs= - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - ex, - Run, - _ - -# Good variable names regexes, separated by a comma. If names match any regex, -# they will always be accepted -good-names-rgxs= - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -variable-rgx=([a-z_][a-z0-9_]{2,30}$|c|ns|tp|op|e1|e2) - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[LOGGING] - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - -# Regular expression of note tags to take in consideration. -#notes-rgx= - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp, - __post_init__ - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=9 - -# Maximum number of attributes for a class (see R0902). -max-attributes=10 - -# Maximum number of boolean expressions in an if statement (see R0916). -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception diff --git a/src/args.py b/src/args.py index 094f176f..33f8f1b6 100644 --- a/src/args.py +++ b/src/args.py @@ -107,7 +107,7 @@ parser.add_argument( "--language", default="kotlin", - choices=['kotlin', 'groovy', 'java'], + choices=['kotlin', 'groovy', 'java', 'typescript'], help="Select specific language" ) parser.add_argument( @@ -236,6 +236,9 @@ def validate_args(args): if args.examine and not args.replay: sys.exit("You cannot use --examine option without the --replay option") + if args.language == "typescript" and not args.disable_use_site_variance: + sys.exit("\nTypeScript does not support use-site variance.\nRun Hephaestus again with the flag --disable-use-site-variance.") + def pre_process_args(args): # PRE-PROCESSING diff --git a/src/compilers/typescript.py b/src/compilers/typescript.py new file mode 100644 index 00000000..3354bde6 --- /dev/null +++ b/src/compilers/typescript.py @@ -0,0 +1,28 @@ +import re +import os +from src.compilers.base import BaseCompiler + + +class TypeScriptCompiler(BaseCompiler): + + ERROR_REGEX = re.compile( + r'[.\/]+(/[A-Za-z0-9_\-\/]+\.ts).*error.*(TS\d+): (.+)') + + CRASH_REGEX = re.compile(r'(.+)(\n(\s+at .+))+') + + def __init__(self, input_name, filter_patterns=None): + super().__init__(input_name, filter_patterns) + + @classmethod + def get_compiler_version(cls): + return ['tsc', '-v'] + + def get_compiler_cmd(self): + return ['tsc --target es2020 --pretty false', os.path.join( + self.input_name, '**', '*.ts')] + + def get_filename(self, match): + return match[0] + + def get_error_msg(self, match): + return f"{match[1]}{match[2]}" diff --git a/src/generators/generator.py b/src/generators/generator.py index 9ffc173d..b07626a6 100644 --- a/src/generators/generator.py +++ b/src/generators/generator.py @@ -28,7 +28,7 @@ from src.generators import generators as gens from src.generators import utils as gu from src.generators.config import cfg -from src.ir import ast, types as tp, type_utils as tu, kotlin_types as kt +from src.ir import ast, types as tp, type_utils as tu, kotlin_types as kt, typescript_types as tst from src.ir.context import Context from src.ir.builtins import BuiltinFactory from src.ir import BUILTIN_FACTORIES @@ -45,7 +45,7 @@ def __init__(self, self.language = language self.logger: Logger = logger self.context: Context = None - self.bt_factory: BuiltinFactory = BUILTIN_FACTORIES[language] + self.bt_factory: BuiltinFactory = BUILTIN_FACTORIES[language]() self.depth = 1 self._vars_in_context = defaultdict(lambda: 0) self._new_from_class = None @@ -94,7 +94,7 @@ def generate(self, context=None) -> ast.Program: cfg.limits.max_top_level): self.gen_top_level_declaration() self.generate_main_func() - return ast.Program(self.context, self.language) + return ast.Program(self.context, self.language, self.bt_factory) def gen_top_level_declaration(self): """Generate a top-level declaration and add it in the context. @@ -110,12 +110,13 @@ def gen_top_level_declaration(self): declarations. """ candidates = [ - self.gen_variable_decl, - self.gen_class_decl, - self.gen_func_decl, + lambda gen: gen.gen_variable_decl(), + lambda gen: gen.gen_class_decl(), + lambda gen: gen.gen_func_decl(), ] + candidates.extend(self.bt_factory.get_decl_candidates()) gen_func = ut.random.choice(candidates) - gen_func() + gen_func(self) def generate_main_func(self) -> ast.FunctionDeclaration: """Generate the main function. @@ -215,7 +216,7 @@ def gen_func_decl(self, Returns: A function declaration node. """ - func_name = func_name or gu.gen_identifier('lower') + func_name = func_name or ut.random.identifier('lower') initial_namespace = self.namespace if namespace: @@ -273,7 +274,8 @@ def gen_func_decl(self, if ( ut.random.bool(prob=0.25) or self.language == 'java' or - self.language == 'groovy' and is_interface + self.language == 'groovy' and is_interface or + self.language == 'typescript' and is_interface ) else self._gen_func_params_with_default() ) @@ -338,7 +340,7 @@ def gen_param_decl(self, etype=None) -> ast.ParameterDeclaration: Args: etype: Parameter type. """ - name = gu.gen_identifier('lower') + name = ut.random.identifier('lower') if etype and etype.is_wildcard(): bound = etype.get_bound_rec() param_type = bound or self.select_type(exclude_covariants=True) @@ -371,7 +373,7 @@ def gen_class_decl(self, Returns: A class declaration node. """ - class_name = class_name or gu.gen_identifier('capitalize') + class_name = class_name or ut.random.identifier('capitalize') initial_namespace = self.namespace self.namespace += (class_name,) initial_depth = self.depth @@ -536,22 +538,22 @@ def _add_node_to_class(self, cls, node): def _add_node_to_parent(self, parent_namespace, node): node_type = { - ast.FunctionDeclaration: self.context.add_func, - ast.ClassDeclaration: self.context.add_class, - ast.VariableDeclaration: self.context.add_var, - ast.FieldDeclaration: self.context.add_var, - ast.ParameterDeclaration: self.context.add_var, - ast.Lambda: self.context.add_lambda, + ast.FunctionDeclaration: lambda gen, p, n, nd: gen.context.add_func(p, n, nd), + ast.ClassDeclaration: lambda gen, p, n, nd: gen.context.add_class(p, n, nd), + ast.VariableDeclaration: lambda gen, p, n, nd: gen.context.add_var(p, n, nd), + ast.FieldDeclaration: lambda gen, p, n, nd: gen.context.add_var(p, n, nd), + ast.ParameterDeclaration: lambda gen, p, n, nd: gen.context.add_var(p, n, nd), + ast.Lambda: lambda gen, p, n, nd: gen.context.add_lambda(p, n, nd), } + node_type.update(self.bt_factory.update_add_node_to_parent()) if parent_namespace == ast.GLOBAL_NAMESPACE: - node_type[type(node)](parent_namespace, node.name, node) + node_type[type(node)](self, parent_namespace, node.name, node) return parent = self.context.get_decl(parent_namespace[:-1], parent_namespace[-1]) if parent and isinstance(parent, ast.ClassDeclaration): self._add_node_to_class(parent, node) - - node_type[type(node)](parent_namespace, node.name, node) + node_type[type(node)](self, parent_namespace, node.name, node) # And @@ -710,7 +712,7 @@ def _gen_func_from_existing(self, def _gen_type_params_from_existing(self, func: ast.FunctionDeclaration, type_var_map - ) -> (List[tp.TypeParameter], tu.TypeVarMap): + ) -> Tuple[List[tp.TypeParameter], tu.TypeVarMap]: """Gen type parameters for a function that overrides a parameterized function. @@ -773,7 +775,7 @@ def gen_field_decl(self, etype=None, etype: Field type. class_is_final: Is the class final. """ - name = gu.gen_identifier('lower') + name = ut.random.identifier('lower') can_override = not class_is_final and ut.random.bool() is_final = ut.random.bool() field_type = etype or self.select_type(exclude_contravariants=True, @@ -813,7 +815,7 @@ def gen_variable_decl(self, vtype = var_type.get_bound_rec() if var_type.is_wildcard() else \ var_type var_decl = ast.VariableDeclaration( - gu.gen_identifier('lower'), + ut.random.identifier('lower'), expr=expr, is_final=is_final, var_type=vtype, @@ -821,6 +823,7 @@ def gen_variable_decl(self, self._add_node_to_parent(self.namespace, var_decl) return var_decl + ##### Expressions ##### def _get_class_attributes(self, class_decl, attr_name): @@ -1533,7 +1536,7 @@ def _gen_func_call(self, if not param.vararg: arg = self.generate_expr(expr_type, only_leaves, gen_bottom=gen_bottom) - if param.default: + if param.default and self.language != 'typescript': if self.language == 'kotlin' and ut.random.bool(): # Randomly skip some default arguments. args.append(ast.CallArgument(arg, name=param.name)) @@ -1890,7 +1893,9 @@ def gen_fun_call(etype): self.bt_factory.get_array_type().name: ( lambda x: self.gen_array_expr(x, only_leaves, subtype=subtype) ), + self.bt_factory.get_null_type().name: lambda x: ast.Null } + constant_candidates.update(self.bt_factory.get_constant_candidates(constant_candidates)) binary_ops = { self.bt_factory.get_boolean_type(): [ lambda x: self.gen_logical_expr(x, only_leaves), @@ -1945,7 +1950,8 @@ def get_types(self, exclude_covariants=False, exclude_contravariants=False, exclude_type_vars=False, - exclude_function_types=False) -> List[tp.Type]: + exclude_function_types=False, + exclude_native_compound_types=False) -> List[tp.Type]: """Get all available types. Including user-defined types, built-ins, and function types. @@ -1959,6 +1965,7 @@ def get_types(self, exclude_contravariants: exclude contravariant type parameters. exclude_type_vars: exclude type variables. exclude_function_types: exclude function types. + exclude_native_compound_types: exclude native compound types. Returns: A list of available types. @@ -1976,7 +1983,6 @@ def get_types(self, if exclude_contravariants and variance == tp.Contravariant: continue type_params.append(t_param) - if type_params and ut.random.bool(): return type_params @@ -1988,16 +1994,21 @@ def get_types(self, t for t in builtins if t.name != self.bt_factory.get_array_type().name ] + + compound_types = (self.bt_factory.get_compound_types(self) + if not exclude_native_compound_types + else []) if exclude_function_types: - return usr_types + builtins - return usr_types + builtins + self.function_types + return usr_types + builtins + compound_types + return usr_types + builtins + compound_types + self.function_types def select_type(self, ret_types=True, exclude_arrays=False, exclude_covariants=False, exclude_contravariants=False, - exclude_function_types=False) -> tp.Type: + exclude_function_types=False, + exclude_native_compound_types=False) -> tp.Type: """Select a type from the all available types. It will always instantiating type constructors to parameterized types. @@ -2008,8 +2019,8 @@ def select_type(self, exclude_arrays: exclude array types. exclude_covariants: exclude covariant type parameters. exclude_contravariants: exclude contravariant type parameters. - exclude_type_vars: exclude type variables. exclude_function_types: exclude function types. + exclude_native_compound_types: exclude native compound types. Returns: Returns a type. @@ -2018,7 +2029,8 @@ def select_type(self, exclude_arrays=exclude_arrays, exclude_covariants=exclude_covariants, exclude_contravariants=exclude_contravariants, - exclude_function_types=exclude_function_types) + exclude_function_types=exclude_function_types, + exclude_native_compound_types=exclude_native_compound_types) stype = ut.random.choice(types) if stype.is_type_constructor(): exclude_type_vars = stype.name == self.bt_factory.get_array_type().name @@ -2027,7 +2039,8 @@ def select_type(self, exclude_covariants=True, exclude_contravariants=True, exclude_type_vars=exclude_type_vars, - exclude_function_types=exclude_function_types), + exclude_function_types=exclude_function_types, + exclude_native_compound_types=exclude_native_compound_types), enable_pecs=self.enable_pecs, disable_variance_functions=self.disable_variance_functions, variance_choices={} @@ -2793,7 +2806,7 @@ def _gen_matching_class(self, declaration (field or function). """ initial_namespace = self.namespace - class_name = gu.gen_identifier('capitalize') + class_name = ut.random.identifier('capitalize') type_params = None # Get return type, type_var_map, and flag for wildcards @@ -2886,8 +2899,8 @@ def _create_type_params_from_etype(self, etype: tp.Type): type_params[0].variance = tp.Invariant return type_params, {etype: type_params[0]}, True - # the given type is parameterized - assert isinstance(etype, (tp.ParameterizedType, tp.WildCardType)) + # the given type is compound + assert etype.is_compound() or etype.is_wildcard() type_vars = etype.get_type_variables(self.bt_factory) type_params = self.gen_type_params( len(type_vars), with_variance=self.language == 'kotlin') diff --git a/src/generators/generators.py b/src/generators/generators.py index 32661cd0..5090f7db 100644 --- a/src/generators/generators.py +++ b/src/generators/generators.py @@ -15,7 +15,7 @@ def gen_string_constant(expr_type=None) -> ast.StringConstant: """Generate a string constant. """ - return ast.StringConstant(gu.gen_identifier()) + return ast.StringConstant(utils.random.identifier()) # pylint: disable=unused-argument diff --git a/src/generators/utils.py b/src/generators/utils.py index 08208131..3fe3a537 100644 --- a/src/generators/utils.py +++ b/src/generators/utils.py @@ -98,23 +98,3 @@ def init_variance_choices(type_var_map: tu.TypeVarMap) -> tu.VarianceChoices: type_var = type_var.bound variance_choices[type_var] = (False, False) return variance_choices - - -def gen_identifier(ident_type:str=None) -> str: - """Generate an identifier name. - - Args: - ident_type: None or 'capitalize' or 'lower' - - Raises: - AssertionError: Raises an AssertionError if the ident_type is neither - 'capitalize' nor 'lower'. - """ - word = ut.random.word() - if ident_type is None: - return word - if ident_type == 'lower': - return word.lower() - if ident_type == 'capitalize': - return word.capitalize() - raise AssertionError("ident_type should be 'capitalize' or 'lower'") diff --git a/src/ir/__init__.py b/src/ir/__init__.py index f87d4e09..2e8bf1bf 100644 --- a/src/ir/__init__.py +++ b/src/ir/__init__.py @@ -1,9 +1,11 @@ from src.ir.kotlin_types import KotlinBuiltinFactory from src.ir.groovy_types import GroovyBuiltinFactory from src.ir.java_types import JavaBuiltinFactory +from src.ir.typescript_types import TypeScriptBuiltinFactory BUILTIN_FACTORIES = { - "kotlin": KotlinBuiltinFactory(), - "groovy": GroovyBuiltinFactory(), - "java": JavaBuiltinFactory() + "kotlin": KotlinBuiltinFactory, + "groovy": GroovyBuiltinFactory, + "java": JavaBuiltinFactory, + "typescript": TypeScriptBuiltinFactory } diff --git a/src/ir/ast.py b/src/ir/ast.py index f3cb43b3..ea331b16 100644 --- a/src/ir/ast.py +++ b/src/ir/ast.py @@ -5,7 +5,6 @@ import src.ir.type_utils as tu import src.ir.types as types from src import utils -from src.ir import BUILTIN_FACTORIES from src.ir.builtins import BuiltinFactory, FunctionType from src.ir.node import Node @@ -32,10 +31,10 @@ class Expr(Node): class Program(Node): # Set default value to kotlin for backward compatibility - def __init__(self, context, language): + def __init__(self, context, language, bt_factory): self.context = context self.language = language - self.bt_factory: BuiltinFactory = BUILTIN_FACTORIES[language] + self.bt_factory: BuiltinFactory = bt_factory def children(self): return self.context.get_declarations(GLOBAL_NAMESPACE, @@ -860,6 +859,15 @@ def is_bottom(self): return True +class NullConstant(Constant): + def __init__(self): + super().__init__("null") + + def is_equal(self, other): + return isinstance(other, NullConstant) + +Null = NullConstant() + class IntegerConstant(Constant): # TODO: Support Hex Integer literals, binary integer literals? def __init__(self, literal: int, integer_type): @@ -1076,6 +1084,10 @@ class LogicalExpr(BinaryOp): "java": [ Operator('&&'), Operator('||') + ], + "typescript": [ + Operator('&&'), + Operator('||') ] } @@ -1103,6 +1115,10 @@ class EqualityExpr(BinaryOp): "java": [ Operator('=='), Operator('=', is_not=True) + ], + "typescript": [ + Operator('==='), + Operator('==', is_not= True) ] } @@ -1132,6 +1148,12 @@ class ComparisonExpr(BinaryOp): Operator('>='), Operator('<'), Operator('<=') + ], + "typescript": [ + Operator('>'), + Operator('>='), + Operator('<'), + Operator('<=') ] } @@ -1161,6 +1183,12 @@ class ArithExpr(BinaryOp): Operator('-'), Operator('/'), Operator('*') + ], + "typescript": [ + Operator('+'), + Operator('-'), + Operator('/'), + Operator('*') ] } diff --git a/src/ir/builtins.py b/src/ir/builtins.py index dd41b356..88777488 100644 --- a/src/ir/builtins.py +++ b/src/ir/builtins.py @@ -77,6 +77,10 @@ def get_array_type(self): def get_function_type(self, nr_parameters=0): pass + @abstractmethod + def get_null_type(self): + pass + def get_non_nothing_types(self): return [ self.get_any_type(), @@ -92,7 +96,7 @@ def get_non_nothing_types(self): self.get_boolean_type(), self.get_char_type(), self.get_string_type(), - self.get_array_type() + self.get_array_type(), ] def get_number_types(self): @@ -107,12 +111,50 @@ def get_number_types(self): self.get_big_integer_type(), ] + def get_decl_candidates(self): + """ Overwrite this method to return a list + with language-specific AST declaration nodes. + + See implementation in typescript_types.py and + TS-specific AST nodes in typescript_ast.py + + """ + return [] + + def update_add_node_to_parent(self): + """ Overwrite this to update the dict 'node_type' + on src.generators.generator._add_node_to_parent + with the respective + key-value pair for the language-specific AST nodes. + + See implementation in typescript_types.py + + """ + return {} + def get_function_types(self, max_parameters): return [self.get_function_type(i) for i in range(0, max_parameters+1)] def get_nothing(self): raise NotImplementedError + def get_compound_types(self, gen_object): + """ A type is considered compound if it can consist + of other types. This function is used to add a lanuage's + native compound types to the generator. + + Eg. A TypeScript Union Type: string | myClass + + """ + return [] + + def get_constant_candidates(self, constants): + """ Overwrite this function to update the generator + constants with language-specific. + + """ + return {} + class AnyType(Builtin): def __init__(self, name="Any"): diff --git a/src/ir/decorators.py b/src/ir/decorators.py new file mode 100644 index 00000000..75bae264 --- /dev/null +++ b/src/ir/decorators.py @@ -0,0 +1,4 @@ +def two_way_subtyping(is_subtype): + def inner(self, other): + return is_subtype(self, other) or other.two_way_subtyping(self) + return inner diff --git a/src/ir/groovy_types.py b/src/ir/groovy_types.py index 80923c5d..288ad781 100644 --- a/src/ir/groovy_types.py +++ b/src/ir/groovy_types.py @@ -72,6 +72,9 @@ def get_primitive_types(self): BooleanType(primitive=True) ] + def get_null_type(self): + raise Exception("Groovy does not support null types") + def get_non_nothing_types(self): return super().get_non_nothing_types() + self.get_primitive_types() diff --git a/src/ir/java_types.py b/src/ir/java_types.py index 7a95b734..2878f545 100644 --- a/src/ir/java_types.py +++ b/src/ir/java_types.py @@ -62,6 +62,9 @@ def get_big_integer_type(self): def get_function_type(self, nr_parameters=0): return FunctionType(nr_parameters) + def get_null_type(self): + raise Exception("Java does not support null types") + def get_primitive_types(self): return [ ByteType(primitive=True), diff --git a/src/ir/kotlin_types.py b/src/ir/kotlin_types.py index a7383081..367f4734 100644 --- a/src/ir/kotlin_types.py +++ b/src/ir/kotlin_types.py @@ -63,6 +63,9 @@ def get_function_type(self, nr_parameters=0): def get_nothing(self): return NothingType() + def get_null_type(self): + return NullType() + def get_non_nothing_types(self): types = super().get_non_nothing_types() types.extend([ @@ -123,6 +126,17 @@ def get_builtin_type(self): return bt.Number +class NullType(AnyType): + def __init__(self, name="null"): + super().__init__(name) + + def box_type(self): + return NullType(self.name) + + def get_name(self): + return 'null' + + class IntegerType(NumberType): def __init__(self, name="Int"): super().__init__(name) diff --git a/src/ir/type_utils.py b/src/ir/type_utils.py index b4556724..69442255 100644 --- a/src/ir/type_utils.py +++ b/src/ir/type_utils.py @@ -1039,6 +1039,15 @@ def unify_types(t1: tp.Type, t2: tp.Type, factory, class A class B : A() """ + if t2.is_compound() and not t2.is_parameterized(): + return t2.unify_types(t1, factory, same_type) + elif t1.is_compound() and t2.is_type_var(): + bound = t2.get_bound_rec(factory) + if bound is None or t1.is_subtype(bound): + return {t2: t1} + else: + return {} + if same_type and type(t1) != type(t2): return {} @@ -1100,10 +1109,8 @@ class B : A() if not _update_type_var_map(type_var_map, t_var, t_arg1): return {} continue - is_parameterized = isinstance(t_var.bound, - tp.ParameterizedType) - is_parameterized2 = isinstance(t_arg1, - tp.ParameterizedType) + is_parameterized = t_var.bound.is_compound() + is_parameterized2 = t_arg1.is_compound() if is_parameterized and is_parameterized2: res = unify_types(t_arg1, t_var.bound, factory) if not res or any( diff --git a/src/ir/types.py b/src/ir/types.py index 49f514c3..7041a246 100644 --- a/src/ir/types.py +++ b/src/ir/types.py @@ -5,6 +5,7 @@ from typing import List, Dict, Set from src.ir.node import Node +from src.ir.decorators import two_way_subtyping class Variance(object): @@ -66,6 +67,19 @@ def has_type_variables(self): def is_subtype(self, other: Type): raise NotImplementedError("You have to implement 'is_subtype()'") + def two_way_subtyping(self, other: Type): + """ + Overwritten when a certain type needs + two-way subtyping checks. + + Eg. when checking if a string type is a subtype + of union type 'Foo | string' we call this method + as `union-type.two_way_subtyping(string_type)` + to check from the union's side. + + """ + return False + def is_assignable(self, other: Type): """ Checks of a value of the current type is assignable to 'other' type. @@ -86,6 +100,9 @@ def is_type_var(self): def is_wildcard(self): return False + def is_compound(self): + return False + def is_parameterized(self): return False @@ -107,6 +124,10 @@ def get_supertypes(self): stack.append(supertype) return visited + def substitute_type(self, type_map, + cond=lambda t: t.has_type_variables()): + return self + def not_related(self, other: Type): return not(self.is_subtype(other) or other.is_subtype(self)) @@ -153,6 +174,7 @@ def __hash__(self): """Hash based on the Type""" return hash(str(self.__class__)) + @two_way_subtyping def is_subtype(self, other: Type) -> bool: return other == self or other in self.get_supertypes() @@ -219,6 +241,7 @@ def _check_supertypes(self): str(t_class[0].t_constructor) + " " + \ "do not have the same types" + @two_way_subtyping def is_subtype(self, other: Type) -> bool: supertypes = self.get_supertypes() # Since the subtyping relation is transitive, we must also check @@ -273,10 +296,25 @@ def get_bound_rec(self, factory): # are out of scope in the context where we use this bound. return t.to_type_variable_free(factory) + @two_way_subtyping def is_subtype(self, other): if not self.bound: return False - return self.bound == other + return self.bound.is_subtype(other) + + def substitute_type(self, type_map, + cond=lambda t: t.has_type_variables()): + t = type_map.get(self) + if t is None or cond(t): + # Perform type substitution on the bound of the current type + # variable. + if self.bound is not None: + new_bound = self.bound.substitute_type(type_map, cond) + return TypeParameter(self.name, self.variance, new_bound) + # The type parameter does not correspond to an abstract type + # so, there is nothing to substitute. + return self + return t def __eq__(self, other): return (self.__class__ == other.__class__ and @@ -302,6 +340,7 @@ def __init__(self, bound=None, variance=Invariant): self.bound = bound self.variance = variance + @two_way_subtyping def is_subtype(self, other): if isinstance(other, WildCardType): if other.bound is not None: @@ -320,11 +359,23 @@ def get_type_variables(self, factory): return self.bound.get_type_variables(factory) elif self.bound.is_type_var(): return {self.bound: {self.bound.get_bound_rec(factory)}} - elif self.bound.is_parameterized(): + elif self.bound.is_compound(): return self.bound.get_type_variables(factory) else: return {} + def substitute_type(self, type_map, + cond=lambda t: t.has_type_variables()): + if self.bound is not None: + new_bound = self.bound.substitute_type(type_map, cond) + return WildCardType(new_bound, variance=self.variance) + t = type_map.get(self) + if t is None or cond(t): + # The bound does not correspond to abstract type + # so there is nothing to substitute + return self + return t + def get_bound_rec(self): if not self.bound: return None @@ -367,42 +418,8 @@ def is_primitive(self): return False -def _get_type_substitution(etype, type_map, - cond=lambda t: t.has_type_variables()): - if etype.is_parameterized(): - return substitute_type_args(etype, type_map, cond) - if etype.is_wildcard() and etype.bound is not None: - new_bound = _get_type_substitution(etype.bound, type_map, cond) - return WildCardType(new_bound, variance=etype.variance) - t = type_map.get(etype) - if t is None or cond(t): - # Perform type substitution on the bound of the current type variable. - if etype.is_type_var() and etype.bound is not None: - new_bound = _get_type_substitution(etype.bound, type_map, cond) - return TypeParameter(etype.name, etype.variance, new_bound) - # The type parameter does not correspond to an abstract type - # so, there is nothing to substitute. - return etype - return t - - -def substitute_type_args(etype, type_map, - cond=lambda t: t.has_type_variables()): - assert etype.is_parameterized() - type_args = [] - for t_arg in etype.type_args: - type_args.append(_get_type_substitution(t_arg, type_map, cond)) - new_type_map = { - tp: type_args[i] - for i, tp in enumerate(etype.t_constructor.type_parameters) - } - type_con = perform_type_substitution( - etype.t_constructor, new_type_map, cond) - return ParameterizedType(type_con, type_args) - - def substitute_type(t, type_map): - return _get_type_substitution(t, type_map, lambda t: False) + return t.substitute_type(type_map, lambda t: False) def perform_type_substitution(etype, type_map, @@ -425,7 +442,7 @@ class X : Y>() supertypes = [] for t in etype.supertypes: if t.is_parameterized(): - supertypes.append(substitute_type_args(t, type_map)) + supertypes.append(t.substitute_type(type_map)) else: supertypes.append(t) type_params = [] @@ -471,6 +488,7 @@ def __hash__(self): def is_type_constructor(self): return True + @two_way_subtyping def is_subtype(self, other: Type): supertypes = self.get_supertypes() matched_supertype = None @@ -526,7 +544,7 @@ def _to_type_variable_free(t: Type, t_param, factory) -> Type: ) ) return WildCardType(bound, variance) - elif t.is_parameterized(): + elif t.is_compound(): return t.to_type_variable_free(factory) else: return t @@ -582,6 +600,9 @@ def __init__(self, t_constructor: TypeConstructor, # XXX revisit self.supertypes = copy(self.t_constructor.supertypes) + def is_compound(self): + return True + def is_parameterized(self): return True @@ -651,18 +672,29 @@ def get_type_variables(self, factory) -> Dict[TypeParameter, Set[Type]]: # This function actually returns a dict of the enclosing type variables # along with the set of their bounds. type_vars = defaultdict(set) - for i, t_arg in enumerate(self.type_args): - t_arg = t_arg + for t_arg in self.type_args: if t_arg.is_type_var(): type_vars[t_arg].add( t_arg.get_bound_rec(factory)) - elif t_arg.is_parameterized() or t_arg.is_wildcard(): + elif t_arg.is_compound() or t_arg.is_wildcard(): for k, v in t_arg.get_type_variables(factory).items(): type_vars[k].update(v) else: continue return type_vars + def substitute_type(self, type_map, cond=lambda t: t.has_type_variables()): + type_args = [] + for t_arg in self.type_args: + type_args.append(t_arg.substitute_type(type_map, cond)) + new_type_map = { + tp: type_args[i] + for i, tp in enumerate(self.t_constructor.type_parameters) + } + type_con = perform_type_substitution( + self.t_constructor, new_type_map, cond) + return ParameterizedType(type_con, type_args) + @property def can_infer_type_args(self): return self._can_infer_type_args @@ -695,6 +727,7 @@ def get_name(self): return "{}<{}>".format(self.name, ", ".join([t.get_name() for t in self.type_args])) + @two_way_subtyping def is_subtype(self, other: Type) -> bool: if super().is_subtype(other): return True diff --git a/src/ir/typescript_ast.py b/src/ir/typescript_ast.py new file mode 100644 index 00000000..51e989d1 --- /dev/null +++ b/src/ir/typescript_ast.py @@ -0,0 +1,27 @@ +import src.ir.ast as ast +import src.ir.types as types +import src.ir.typescript_types as tst + +class TypeAliasDeclaration(ast.Declaration): + def __init__(self, name: str, + alias: types.Type): + self.name = name + self.alias = alias + + def children(self): + return [self.alias] + + def get_type(self): + return tst.AliasType(self.alias, self.name) + + def update_children(self, children): + super().update_children(children) + self.alias = children[0] + + def is_equal(self, other): + if isinstance(other, TypeAliasDeclaration): + return (self.name == other.name and + self.alias == other.alias) + + def __str__(self): + return f'{self.name} (TypeAliasDecl<{str(self.alias)}>)' diff --git a/src/ir/typescript_types.py b/src/ir/typescript_types.py new file mode 100644 index 00000000..941efbc0 --- /dev/null +++ b/src/ir/typescript_types.py @@ -0,0 +1,866 @@ +from collections import defaultdict + +import src.ir.ast as ast +import src.ir.typescript_ast as ts_ast +import src.ir.builtins as bt +import src.ir.types as tp +import src.ir.ast as ast +import src.utils as ut +from src.ir.decorators import two_way_subtyping + + +class TypeScriptBuiltinFactory(bt.BuiltinFactory): + def __init__(self, max_union_types=10, max_types_in_union=4, + max_string_literal_types=10, max_num_literal_types=10): + self._literal_type_factory = LiteralTypeFactory( + max_string_literal_types, max_num_literal_types) + self._union_type_factory = UnionTypeFactory(max_union_types, + max_types_in_union) + + def get_language(self): + return "typescript" + + def get_builtin(self): + return TypeScriptBuiltin + + def get_void_type(self): + return VoidType() + + def get_any_type(self): + return ObjectType() + + def get_number_type(self): + return NumberType(primitive=False) + + def get_boolean_type(self): + return BooleanType(primitive=False) + + def get_char_type(self): + return StringType(primitive=False) + + def get_string_type(self): + return StringType(primitive=False) + + def get_big_integer_type(self): + return BigIntegerType(primitive=False) + + def get_array_type(self): + return ArrayType() + + def get_function_type(self, nr_parameters=0): + return FunctionType(nr_parameters) + + def get_object_type(self): + return ObjectLowercaseType() + + def get_primitive_types(self): + return [ + NumberType(primitive=False), + StringType(primitive=False), + SymbolType(primitive=False), + BooleanType(primitive=False), + BigIntegerType(primitive=False), + NullType(primitive=False), + UndefinedType(primitive=False) + ] + + def get_integer_type(self): + return NumberType(primitive=False) + + def get_byte_type(self): + return NumberType(primitive=False) + + def get_short_type(self): + return NumberType(primitive=False) + + def get_long_type(self): + return NumberType(primitive=False) + + def get_float_type(self): + return NumberType(primitive=False) + + def get_double_type(self): + return NumberType(primitive=False) + + def get_big_decimal_type(self): + return NumberType(primitive=False) + + def get_null_type(self): + return NullType(primitive=False) + + def get_non_nothing_types(self): + # Overwriting Parent method to add TS-specific types + types = super().get_non_nothing_types() + types.extend([ + self.get_null_type(), + UndefinedType(primitive=False), + ] + self._literal_type_factory.get_literal_types()) + return types + + def get_decl_candidates(self): + return [gen_type_alias_decl, ] + + def update_add_node_to_parent(self): + return { + ts_ast.TypeAliasDeclaration: add_type_alias, + } + + def get_compound_types(self, gen_object): + return [ + self._union_type_factory.get_union_type(gen_object), + ] + + def get_constant_candidates(self, constants): + """ Updates the constant candidates of the generator + with the type-constant pairs for language-specific features. + + Args: + gen_object: The generator instance + constants: The dictionary of constant candidates + at the time of the method call + Returns: + A dictionary where the keys are strings of type names and + values are functions that return the appropriate constant + node for the type. + + The constants dictionary is updated at the generator-side + with the method's returned key-value pairs. + + This method is called at src.ir.generator.get_generators() + + """ + return { + "NumberLiteralType": lambda etype: ast.IntegerConstant( + etype.literal, NumberLiteralType), + "StringLiteralType": lambda etype: ast.StringConstant( + etype.literal), + "UnionType": lambda etype: self._union_type_factory.get_union_constant( + etype, constants), + } + + +class TypeScriptBuiltin(tp.Builtin): + def __init__(self, name, primitive): + super().__init__(name) + self.primitive = primitive + + def __str__(self): + if not self.is_primitive(): + return str(self.name) + "(typescript-builtin)" + return str(self.name).lower() + "(typescript-primitive)" + + def is_primitive(self): + return self.primitive + + +class ObjectType(TypeScriptBuiltin): + def __init__(self, name="Object"): + super().__init__(name, False) + + +class ObjectLowercaseType(TypeScriptBuiltin): + def __init__(self, name="object"): + super().__init__(name, False) + self.supertypes.append(ObjectType()) + + +class VoidType(TypeScriptBuiltin): + def __init__(self, name="void"): + super().__init__(name, False) + self.supertypes.append(ObjectType()) + + +class NumberType(TypeScriptBuiltin): + def __init__(self, name="Number", primitive=False): + super().__init__(name, primitive) + self.supertypes.append(ObjectType()) + + def box_type(self): + return NumberType(self.name, primitive=False) + + def get_name(self): + if self.is_primitive: + return "number" + return super().get_name() + + +class BigIntegerType(TypeScriptBuiltin): + def __init__(self, name="BigInt", primitive=False): + super().__init__(name, primitive) + self.supertypes.append(ObjectType()) + + def is_assignable(self, other): + assignable_types= [BigIntegerType] + return self.is_subtype(other) or type(other) in assignable_types + + def box_type(self): + return BigIntegerType(self.name, primitive=False) + + def get_name(self): + if self.is_primitive: + return "bigint" + return super().get_name() + + +class BooleanType(TypeScriptBuiltin): + def __init__(self, name="Boolean", primitive=False): + super().__init__(name, primitive) + self.supertypes.append(ObjectType()) + + def box_type(self): + return BooleanType(self.name, primitive=False) + + def get_name(self): + if self.is_primitive: + return "boolean" + return super().get_name() + + +class StringType(TypeScriptBuiltin): + def __init__(self, name="String", primitive=False): + super().__init__(name, primitive) + self.supertypes.append(ObjectType()) + + def box_type(self): + return StringType(self.name, primitive=False) + + def get_name(self): + if self.is_primitive: + return "string" + return super().get_name() + + +class SymbolType(TypeScriptBuiltin): + def __init__(self, name="Symbol", primitive=False): + super().__init__(name, primitive) + self.supertypes.append(ObjectType()) + + def box_type(self): + return SymbolType(self.name, primitive=False) + + def get_name(self): + if self.is_primitive(): + return "symbol" + return super().get_name() + + +class NullType(ObjectType): + def __init__(self, name="null", primitive=False): + super().__init__(name) + self.primitive = primitive + + def box_type(self): + return NullType(self.name) + + def get_name(self): + return 'null' + + +class UndefinedType(ObjectType): + def __init__(self, name="undefined", primitive=False): + super().__init__(name) + self.primitive = primitive + + def box_type(self): + return UndefinedType(self.name) + + def get_name(self): + return 'undefined' + + +class AliasType(ObjectType): + def __init__(self, alias, name="AliasType", primitive=False): + super().__init__() + self.alias = alias + self.name = name + self.primitive = primitive + + def get_type(self): + return self.alias + + @two_way_subtyping + def is_subtype(self, other): + if isinstance(other, AliasType): + return self.alias.is_subtype(other.alias) + return self.alias.is_subtype(other) + + def box_type(self): + return AliasType(self.alias, self.name) + + def get_name(self): + return self.name + + def __eq__(self, other): + return (isinstance(other, AliasType) and + self.alias == other.alias) + + def __hash__(self): + return hash(str(self.name) + str(self.alias)) + + +class NumberLiteralType(TypeScriptBuiltin): + def __init__(self, literal, name="NumberLiteralType", primitive=False): + super().__init__(name, primitive) + self.literal = literal + self.supertypes.append(NumberType()) + + def get_literal(self): + return self.literal + + @two_way_subtyping + def is_subtype(self, other): + """ A number literal type is assignable to any + supertype of type 'number'. + + It is also assignable to other number literal types, + as long as the other type's literal is the same. + + eg. let num: number + let litA: 23 = 23 + let litB: 23 + num = litA (correct) + litB = litA (correct) + + litA is assignable to litB because their literal + is the same, 23. + + """ + if (isinstance(other, AliasType) and isinstance(other.alias, NumberLiteralType)): + other = other.alias + elif isinstance(other, AliasType): + return isinstance(other.alias, NumberType) + + return ((isinstance(other, NumberLiteralType) and + other.get_literal() == self.get_literal()) or + isinstance(other, NumberType)) + + def get_name(self): + return self.name + + def __eq__(self, other): + return (self.__class__ == other.__class__ and + self.name == other.name and + self.literal == other.literal) + + def __hash__(self): + return hash(str(self.name) + str(self.literal)) + + +class StringLiteralType(TypeScriptBuiltin): + def __init__(self, literal, name="StringLiteralType", primitive=False): + super().__init__(name, primitive) + self.literal = literal + self.supertypes.append(StringType()) + + def get_literal(self): + return '"' + self.literal + '"' + + @two_way_subtyping + def is_subtype(self, other): + """ A string literal type is assignable to any + supertype of type 'string'. + + It is also assignablde to other string literal types, + as long as the other type's literal is the same. + + eg. let str: string + let litA: "PULL" = "PULL" + let litB: "PULL" + str = litA (correct) + litB = litA (correct) + + litA is assignable to litB because their literal + is the same, "PULL". + + """ + if (isinstance(other, AliasType) and isinstance(other.alias, StringLiteralType)): + other = other.alias + elif isinstance(other, AliasType): + return isinstance(other.alias, StringType) + + return ((isinstance(other, StringLiteralType) and + other.get_literal() == self.get_literal()) or + isinstance(other, StringType)) + + def get_name(self): + return self.name + + def __eq__(self, other): + return (self.__class__ == other.__class__ and + self.name == other.name and + self.literal == other.literal) + + def __hash__(self): + return hash(str(self.name) + str(self.literal)) + + +class LiteralTypeFactory: + def __init__(self, str_limit, num_limit): + self.str_literals = [] + self.num_literals = [] + # Define max number for generated literals + self.str_limit = str_limit + self.num_limit = num_limit + + def get_literal_types(self): + sl = self.gen_string_literal() + nl = self.gen_number_literal() + return [sl, nl] + + def gen_string_literal(self): + lit = None + if (len(self.str_literals) == 0 or + (len(self.str_literals) < self.str_limit and + ut.random.bool())): + # If the limit for generated literals + # has not been surpassed, we can randomly + # generate a new one. + lit = StringLiteralType(ut.random.word().lower()) + self.str_literals.append(lit) + else: + lit = ut.random.choice(self.str_literals) + return lit + + def gen_number_literal(self): + lit = None + if (len(self.num_literals) == 0 or + (len(self.num_literals) < self.num_limit and + ut.random.bool())): + # If the limit for generated literals + # has not been surpassed, we can randomly + # generate a new one. + lit = NumberLiteralType(ut.random.integer(-100, 100)) + self.num_literals.append(lit) + else: + lit = ut.random.choice(self.num_literals) + return lit + + +class UnionType(TypeScriptBuiltin): + def __init__(self, types, name="UnionType", primitive=False): + super().__init__(name, primitive) + self.types = types + + def get_types(self): + return self.types + + def is_compound(self): + return True + + @two_way_subtyping + def is_subtype(self, other): + if isinstance(other, UnionType): + for t in self.types: + if not any(t.is_subtype(other_t) for other_t in other.types): + return False + return True + return other.name == 'Object' + + def two_way_subtyping(self, other): + return other in set(self.types) + + def substitute_type(self, type_map, + cond=lambda t: t.has_type_variables()): + new_types = [] + for t in self.types: + new_t = (t.substitute_type(type_map, cond) + if t.has_type_variables() + else t) + new_types.append(new_t) + return UnionType(new_types) + + def has_type_variables(self): + return any(t.has_type_variables() for t in self.types) + + def get_type_variables(self, factory): + # This function actually returns a dict of the enclosing type variables + # along with the set of their bounds. + type_vars = defaultdict(set) + for t in self.types: + if t.is_type_var(): + type_vars[t].add( + t.get_bound_rec(factory)) + elif t.is_compound() or t.is_wildcard(): + for k, v in t.get_type_variables(factory).items(): + type_vars[k].update(v) + else: + continue + return type_vars + + def to_variance_free(self, type_var_map=None): + new_types = [] + for t in self.types: + new_types.append(t.to_variance_free(type_var_map) + if t.is_compound() + else t) + return UnionType(new_types) + + def to_type_variable_free(self, factory): + # We translate a union type that contains + # type variables into a parameterized type that is + # type variable free. + new_types = [] + for t in self.types: + if t.is_compound(): + new_type = t.to_type_variable_free(factory) + elif t.is_type_var(): + bound = t.get_bound_rec(factory) + new_type = factory.get_any_type() if bound is None else bound + else: + new_type = t + new_types.append(new_type) + return UnionType(new_types) + + def unify_types(self, t1, factory, same_type=True): + """ + This is used in src.ir.type_utils in the function + unify_types. + + We delegate work here when the first of the two + types that are passed to that function is a union type. + + For more information on the function see the detailed + explanation at the unify_types function definition. + """ + t2 = self + type_var_map = {} + + if not t2.has_type_variables(): + return {} + + # If T1 is a union type, then get all its types. + t1_types = (t1.types if t1.is_compound() and + not t1.is_parameterized() + else [t1]) + + if not t1.is_subtype(t2): + # Get the Type Variables of T2 + t_vars = dict(t2.get_type_variables(factory)) + + # Find which types of t1 are not already in t2 + add_to_t2 = set(t1_types) - set(t2.types) + + # If T1 is a union type like 100 | number | string, + # we do not need to substitute both 100 and number + # in the type variables of T2. + # Since number is a supertype of 100, if we only + # substitute number, then we have also covered + # the subtypes of 100 too! + # Hence, we only substitute necessary types in T2 + # by ensuring that the T1 type we will subtitute + # is NOT a subtype of any other types in the T1 union type. + for t1_t in list(add_to_t2): + if any(t1_t.is_subtype(other_t1_t) + and t1_t is not other_t1_t + for other_t1_t in t1_types): + add_to_t2.remove(t1_t) + + # If T1 is a union type, and its types that we need + # to substitute in T2 are more than the type variables of T2, + # then there is no substitution that ensures T1 <: T2. + if len(add_to_t2) > len(t_vars): + return {} + + # Get bounds of type variables of T2 + bounds = [b for b in t_vars.values() if b != {None}] + + # If the type variables have no bounds, then we can just assign + # the types from T1 to any type variable. + if not bounds: + temp = list(t_vars.keys()) + for t in add_to_t2: + tv = temp.pop(0) + type_var_map[tv] = t + return type_var_map + + # Get all the possible substitutions between T1 types (add_to_t2) + # and type variables of T2. A type variable can be substituted + # with a type in T1 if the type variable has no bound or + # if the type is a subtype of a bound. + possible_substitutions = {} + for t in add_to_t2: + subs = set() + # Below remember that: + # - k is the type variable + # - v is a set containing its bounds + for k,v in t_vars.items(): + if v == {None} or any(t.is_subtype(b) for b in v): + subs.add(k) + + # If there are no possible substitutions with type variables + # for any given type in T1 (add_to_t2) then there is no + # substitution that ensures T1 <: T2. + if not subs: + return {} + possible_substitutions[t] = subs + + # Decide the order of assignments (if possible) + assignments, flag = self.assign_types_to_type_vars(possible_substitutions) + if not flag: + return {} + type_var_map.update(assignments) + + + # Instantiate any not-utilized T2 type variables with their bound (if they have one). + # If they don't have a bound instantiate them with a type from T1. + leftover_type_vars = [t for t in t2.types if t.is_type_var() and t not in type_var_map] + for type_var in leftover_type_vars: + type_var_map[type_var] = type_var.bound if type_var.bound else t1_types[0] + + return type_var_map + + def assign_types_to_type_vars(self, possible_subs): + """ + This method is a helper for the method unify_types of union types (see above) + + Args: + - possible_subs: A dict containing (types.Type: set) pairs, which represents possible + T2 type variable (value) susbstitutions with a T1 type (key). + The set values contain compatible type vars of T2. + + - t2_t_vars: The type variables of T2 + + - t1_types: The types of T1 that we want to substitute in T2 + + Returns: + - A dict of (TypeVariable: types.Type) pairs representing the substitutions in T2 + + This method is needed because we need to find the correct substitutions + of the type variables in the T2 union type, in order for all T1 types (T1 can be a union type itself) + to be substituted in T2 (if possible). + + Consider the following case: + + T1: number | string + T2: boolean | X | Y extends number + + In this case, if we first substituted X with the type number from T1, + we would have been left with the type variable Y, which is not compatible + with the type string. + + As a result we would falsely conclude that we can not unify the types T1 and T2, + when in reality, had we just substituted Y with number and X with string, + we would have been ably to correctly unify the two types. + + A naive solution would be to find all the possible substitution permutations + between T1 types and T2 type variables. + + Using our approach, after first creating the possible_subs dict, + which contains all compatible T2 type variables for each T1 type, + we first substitute the T1 type that is compatible with the FEWEST + T2 type variables at any given moment. + + Going back two the above case, here is how we tackle it with our new approach: + + T1: number | string + T2: boolean | X | Y extends number + + (1) We find the possible substitutions for each T1 type (done outside this method) + possible_subs = {number: {X, Y}, + string: {X} + } + + (2) We sort the dict based on the length of the type variable sets corresponding to each type + sorted_type_subs = [(string, {X}), (number, {X, Y})] + + (3) Now we work on the first element, the pair of string and {X}. We assign the substitution + X: string. We remove X from all possible substitutions for other T1 types and then + delete the pair with key string from our possible_subs dict. + + (4) We sort the dict again, it now looks like this: + sorted_type_subs = [(number, {Y})] + + (5) Assign the substitution Y: number and repeat the rest of step (3) + + (6) The possible_subs dict is now empty, so we return our substitution dictionary + return {X: string, Y: number} + """ + type_var_map = {} + + # Continue trying to find type variable susbstitutions until + # all T1 types are substituted in T2. + while possible_subs: + # Sort the possible_subs dict, in order to first find a substitution for the T1 + # type with the fewest compatible T2 type variables. + sorted_type_subs = sorted(list(possible_subs.items()), key=lambda x: len(x[1]), reverse=False) + + # Get the first (T1 type, T2 type variable) pair (sorted_type_subs is a tuples list) + type_to_substitute, compatible_tvars = sorted_type_subs[0] + + # If there aren't any compatible_tvars, then that means that there is no possible + # order of substitutions that ensures all types in T1 are substituted in T2. + # Hence, we return a False flag to indicate that the type unification is not possible. + # Note: at this point this happens if at previous iterations of the while loop + # we substituted all the type variables that are compatible with this specific type_to_substitute. + if not compatible_tvars: + return ({}, False) + + # Get any of the compatible type variables and substitute it with the T1 type + chosen_tvar = compatible_tvars.pop() + type_var_map[chosen_tvar] = type_to_substitute + + # Remove the substituted type variable from the possible substitutions + # of all other T1 types. + for k in list(possible_subs.keys()): + if chosen_tvar in possible_subs[k]: + possible_subs[k].remove(chosen_tvar) + + # Delete the possible substitutions of the T1 type we just substituted in T2 + del possible_subs[type_to_substitute] + + # Return the substitutions we found and a flag confirming that there is a possible + # order of substitutions that gurantees type unification. + return (type_var_map, True) + + def get_name(self): + return self.name + + def __str__(self): + return f'UnionType(TS)({" | ".join([str(t) for t in self.types])})' + + def __eq__(self, other): + return (self.__class__ == other.__class__ and + self.name == other.name and + set(self.types) == set(other.types)) + + def __hash__(self): + return hash(str(self.name) + str(self.types)) + + +class UnionTypeFactory(object): + def __init__(self, max_ut, max_in_union): + self.max_ut = max_ut + self.unions = [] + self.max_in_union = max_in_union + + def get_number_of_types(self): + return ut.random.integer(2, self.max_in_union) + + def get_types_for_union(self, gen): + num_of_types = self.get_number_of_types() + types = set() + while len(types) < num_of_types: + t = gen.select_type(exclude_native_compound_types=True) + types.add(t) + return list(types) + + def gen_union_type(self, gen): + """ Generates a union type that consists of N types + where N is a number in [2, self.max_in_union]. + + Args: + gen - Instance of Hephaestus' generator + + """ + types = self.get_types_for_union(gen) + gen_union = UnionType(types) + self.unions.append(gen_union) + return gen_union + + def get_union_type(self, gen_object): + """ Returns a previously created union type + or a newly generated at random. + + If there are previously generated union types + and they have not exceeded the limit, we make a + probabilistic choice on whether to pick one of + the already generated types or create a new one. + + """ + generated = len(self.unions) + if generated == 0: + return self.gen_union_type(gen_object) + if generated >= self.max_ut or ut.random.bool(): + union_t = ut.random.choice(self.unions) + if union_t.has_type_variables(): + # We might have selected a union type that holds a type + # variable. However, we must be careful because it might be + # no possible to use the selected union type since it uses + # a type variable that is out of context. + return self.gen_union_type(gen_object) + else: + return union_t + return self.gen_union_type(gen_object) + + def get_union_constant(self, utype, constants): + """ This method randomly chooses one of the types in a type's + union and then assigns the union a constant value that matches + the randomly selected type. + + A union type can have types like 'Object' or 'undefined' + as part of its union, which however do not have a respective + constant equivalent. + + Hence, we only consider types that we can generate a constant + from. If there is none, we revert to a bottom constant. + + TODO revisit this after implementing structural types. + + """ + type_candidates = [t for t in utype.types if t.name in constants] + if len(type_candidates) == 0: + return ast.BottomConstant(utype.types[0]) + t = ut.random.choice(type_candidates) + return constants[t.name](t) + + +class ArrayType(tp.TypeConstructor, ObjectType): + def __init__(self, name="Array"): + # In TypeScript, arrays are covariant. + super().__init__(name, [tp.TypeParameter( + "T", variance=tp.Covariant)]) + + +class FunctionType(tp.TypeConstructor): + def __init__(self, nr_type_parameters: int): + name = "Function" + str(nr_type_parameters) + + # In Typescript, type parameters are covariant as to the return type + # and contravariant as to the arguments. + + type_parameters = [ + tp.TypeParameter("A" + str(i), tp.Contravariant) + for i in range(1, nr_type_parameters + 1) + ] + [tp.TypeParameter("R", tp.Covariant)] + self.nr_type_parameters = nr_type_parameters + super().__init__(name, type_parameters) + self.supertypes.append(ObjectType()) + + +# Generator Extension + +""" The below functions are all passed as candidate +generation functions to the Hephaestus generator +in order for it to be able to work with language-specific +features of typescript. +""" + + +def gen_type_alias_decl(gen, + etype=None) -> ts_ast.TypeAliasDeclaration: + """ Generate a Type Declaration (Type Alias) + + Args: + etype: the type(s) that the type alias describes + + Returns: + An AST node that describes a type alias declaration + as defined in src.ir.typescript_ast.py + + """ + alias_type = (etype if etype else + gen.select_type()) + initial_depth = gen.depth + gen.depth += 1 + gen.depth = initial_depth + type_alias_decl = ts_ast.TypeAliasDeclaration( + name=ut.random.identifier('lower'), + alias=alias_type + ) + gen._add_node_to_parent(gen.namespace, type_alias_decl) + return type_alias_decl + + +def add_type_alias(gen, namespace, type_name, ta_decl): + gen.context._add_entity(namespace, 'types', type_name, ta_decl.get_type()) + gen.context._add_entity(namespace, 'decls', type_name, ta_decl) diff --git a/src/ir/visitors.py b/src/ir/visitors.py index 52629ab8..63ba87a2 100644 --- a/src/ir/visitors.py +++ b/src/ir/visitors.py @@ -8,7 +8,15 @@ def result(self): raise NotImplementedError('result() must be implemented') def visit(self, node): - visitors = { + visitors = self.get_visitors() + visitor = visitors.get(node.__class__) + if visitor is None: + raise Exception( + "Cannot find visitor for instance node " + str(node.__class__)) + return visitor(node) + + def get_visitors(self): + return { ast.SuperClassInstantiation: self.visit_super_instantiation, ast.ClassDeclaration: self.visit_class_decl, types.TypeParameter: self.visit_type_param, @@ -21,6 +29,7 @@ def visit(self, node): ast.FunctionReference: self.visit_func_ref, ast.BottomConstant: self.visit_bottom_constant, ast.IntegerConstant: self.visit_integer_constant, + ast.NullConstant: self.visit_null_constant, ast.RealConstant: self.visit_real_constant, ast.CharConstant: self.visit_char_constant, ast.StringConstant: self.visit_string_constant, @@ -40,11 +49,6 @@ def visit(self, node): ast.Program: self.visit_program, ast.Block: self.visit_block, } - visitor = visitors.get(node.__class__) - if visitor is None: - raise Exception( - "Cannot find visitor for instance node " + str(node.__class__)) - return visitor(node) def visit_program(self, node): raise NotImplementedError('visit_program() must be implemented') @@ -90,6 +94,9 @@ def visit_integer_constant(self, node): raise NotImplementedError( 'visit_integer_constant() must be implemented') + def visit_null_constant(self, node): + raise NotImplementedError('visit_null_constant() must be implemented') + def visit_real_constant(self, node): raise NotImplementedError('visit_real_constant() must be implemented') @@ -195,6 +202,9 @@ def visit_bottom_constant(self, node): def visit_integer_constant(self, node): return self._visit_node(node) + def visit_null_constant(self, node): + return self._visit_node(node) + def visit_real_constant(self, node): return self._visit_node(node) diff --git a/src/resources/typescript_keywords b/src/resources/typescript_keywords new file mode 100644 index 00000000..614838dc --- /dev/null +++ b/src/resources/typescript_keywords @@ -0,0 +1,11 @@ +frames +Record +undefined +delete +Delete +debugger +arguments +let +export +symbol +infer diff --git a/src/translators/kotlin.py b/src/translators/kotlin.py index beb0b1db..7ae8c002 100644 --- a/src/translators/kotlin.py +++ b/src/translators/kotlin.py @@ -1,13 +1,6 @@ from src.ir import ast, kotlin_types as kt, types as tp, type_utils as tu from src.translators.base import BaseTranslator - - -def append_to(visit): - def inner(self, node): - self._nodes_stack.append(node) - res = visit(self, node) - self._nodes_stack.pop() - return inner +from src.translators.utils import append_to class KotlinTranslator(BaseTranslator): @@ -154,12 +147,6 @@ def visit_class_decl(self, node): is_sam = tu.is_sam(self.context, cls_decl=node) class_prefix = "interface" if is_sam else node.get_class_prefix() - body = "" - if function_res: - body = " {{\n{function_res}\n{old_ident}}}".format( - function_res="\n\n".join(function_res), - old_ident=" " * old_ident - ) res = "{ident}{f}{o}{p} {n}".format( ident=" " * old_ident, @@ -169,10 +156,6 @@ def visit_class_decl(self, node): not is_sam) else "", p=class_prefix, n=node.name, - tps="<" + type_parameters_res + ">" if type_parameters_res else "", - fields="(" + ", ".join(field_res) + ")" if field_res else "", - s=": " + ", ".join(superclasses_res) if superclasses_res else "", - body=body ) if type_parameters_res: diff --git a/src/translators/typescript.py b/src/translators/typescript.py new file mode 100644 index 00000000..0fa5d93b --- /dev/null +++ b/src/translators/typescript.py @@ -0,0 +1,757 @@ +from src.ir import ast, typescript_types as tst, types +from src.transformations.base import change_namespace +from src.ir.context import get_decl +from src.translators.base import BaseTranslator +from src.translators.utils import append_to +import src.ir.typescript_ast as ts_ast + +class TypeScriptTranslator(BaseTranslator): + filename = "Main.ts" + incorrect_filename = "Incorrect.ts" + executable = "Main.js" + ident_value = " " + + def __init__(self, package=None, options={}): + super().__init__(package, options) + self._children_res = [] + self.ident = 0 + self.is_interface = False + self.is_void = False + self.is_lambda = False + self.current_class = None + self.current_function = None + self.context = None + self._namespace: tuple = ast.GLOBAL_NAMESPACE + self._nodes_stack = [None] + + def _reset_state(self): + self._children_res = [] + self.ident = 0 + self.is_interface = False + self.is_void = False + self.is_lambda = False + self.current_class = None + self.current_function = None + self.context = None + self._namespace = ast.GLOBAL_NAMESPACE + self._nodes_stack = [None] + + def get_visitors(self): + # Overwriting method of ASTVisitor class + # to add typescript-specific visitors + visitors = super().get_visitors() + visitors.update({ + ts_ast.TypeAliasDeclaration: self.visit_type_alias_decl, + }) + return visitors + + def needs_this_prefix(self, node, decl): + func_name = tst.TypeScriptBuiltinFactory().get_function_type().name[:-1] + if node.receiver is not None: + return False + + if decl is None: + return True # Function is an inherited method + + if isinstance(decl, ast.FunctionDeclaration) and decl.is_class_method(): + return True # Function is method of current class + + if (isinstance(decl, ast.FieldDeclaration) and + decl.get_type().name[:-1] == func_name): + + return True # Function is callable field + return False + + def is_in_class(self, namespace): + # Checks if node is any nth descendant of a class + for i in range(1, len(namespace)-1): + if namespace[i][0].isupper(): + return True + return False + + @staticmethod + def get_filename(): + return TypeScriptTranslator.filename + + @staticmethod + def get_incorrect_filename(): + return TypeScriptTranslator.incorrect_filename + + def get_union(self, utype): + return " | ".join([self.get_type_name(t, True) for t in utype.types]) + + def type_arg2str(self, t_arg, from_union=False): + # TypeScript does not have a Wildcard type + if not t_arg.is_wildcard(): + return self.get_type_name(t_arg, from_union) + return "unknown" + + def get_type_name(self, t, from_union=False): + t_constructor = getattr(t, 't_constructor', None) + if (isinstance(t, tst.NumberLiteralType) or + isinstance(t, tst.StringLiteralType)): + return str(t.get_literal()) + if t.name == 'UnionType': + return self.get_union(t) + if not t_constructor: + return t.get_name() + + func_name = tst.TypeScriptBuiltinFactory().get_function_type().name[:-1] + if t_constructor.name.startswith(func_name): + param_types = t.type_args[:-1] + ret_type = t.type_args[-1] + res = "({}) => {}".format( + ",".join([ + "p" + str(i) + ": " + str(self.type_arg2str(pt)) + for i, pt in enumerate(param_types) + ]), + self.type_arg2str(ret_type) + ) + if from_union: + return "(" + res + ")" + return res + + return "{}<{}>".format(t.name, ", ".join([self.type_arg2str(ta) + for ta in t.type_args])) + + def pop_children_res(self, children): + len_c = len(children) + if not len_c: + return [] + res = self._children_res[-len_c:] + self._children_res = self._children_res[:-len_c] + return res + + def visit_program(self, node): + self.context = node.context + self.class_decls = [decl for decl in node.declarations + if isinstance(decl, ast.ClassDeclaration)] + children = node.children() + for c in children: + c.accept(self) + res = '\n\n'.join(self.pop_children_res(children)) + self.program = ( + res if self.package is None + else f"module {self.package} {{\n{res}\n}}" + ) + self._reset_state() + + @append_to + def visit_block(self, node): + children = node.children() + is_void = self.is_void + self.is_void = False + is_interface = self.is_interface + is_lambda = self.is_lambda + self.is_lambda = False + self.is_interface = False + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + + if is_interface: + self.is_interface = is_interface + self.is_lambda = is_lambda + self._children_res.append("") + return + + res = "{" if not is_lambda else "" + res += "\n" + ";\n".join(children_res[:-1]) + if children_res[:-1]: + res += ";\n" + ret_keyword = "return " if node.is_func_block and not is_void else "" + if children_res: + res += " "*self.ident + ret_keyword + " " + \ + children_res[-1].strip() + ";\n" + \ + " "*self.ident + else: + res += " "*self.ident + ret_keyword.strip() + ";\n" + \ + " "*self.ident + res += "}" if not is_lambda else "" + self.is_void = is_void + self.is_interface = is_interface + self.is_lambda = is_lambda + self._children_res.append(res) + + @append_to + def visit_super_instantiation(self, node): + old_ident = self.ident + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + class_type = self.get_type_name(node.class_type) + super_call = None + + if node.args is not None: + children_res = [c.strip() for c in children_res] + super_call = "super(" + ", ".join(children_res) + ")" + + res = (class_type, super_call) + + self.ident = old_ident + self._children_res.append(res) + + @append_to + @change_namespace + def visit_class_decl(self, node): + old_ident = self.ident + self.ident += 2 + prev_is_interface = self.is_interface + self.is_interface = node.is_interface() + prev_class = self.current_class + self.current_class = node + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + field_res = [children_res[i] + for i, _ in enumerate(node.fields)] + len_fields = len(field_res) + field_names = [field.name for field in node.fields] + + superclasses_res = [children_res[i + len_fields] + for i, _ in enumerate(node.superclasses)] + len_supercls = len(superclasses_res) + + supertype, supercall = None, None + if len_supercls > 0: + supertype, supercall = superclasses_res[0] + + function_res = [children_res[i + len_fields + len_supercls] + for i, _ in enumerate(node.functions)] + len_functions = len(function_res) + + if node.is_abstract() and supertype is not None and ( + node.superclasses[0].args is None): + # TypeScript requires abstract classes that implement interfaces + # to either implement all methods + # or to re-write the signatures with "abstract" in front. + # This code block resolves the class declaration of the + # interface supertype that the abstract class (node) + # implements and re-visits its FunctionDeclaration children + # in order to add them to its function_res + abstract_funcs = [ + f for f in node.get_abstract_functions(self.class_decls) + if f.name not in [f.name for f in node.functions] + ] + + for func in abstract_funcs: + func.accept(self) + supertype_func_res = self.pop_children_res(abstract_funcs) + + function_res += supertype_func_res + + type_parameters_res = ", ".join( + children_res[len_fields + len_supercls + len_functions:]) + class_prefix = node.get_class_prefix() + + res = "{ident}{p} {n}".format( + ident=" " * old_ident, + p=class_prefix, + n=node.name, + ) + + if type_parameters_res: + res = "{}<{}>".format(res, type_parameters_res) + if supertype is not None: + inheritance = ( + " extends " + if node.is_interface() or node.superclasses[0].args is not None + else " implements " + ) + res += inheritance + supertype + + res += " {\n" + " "*old_ident + + # Makes constructor + if not self.is_interface: + class_ident = self.ident + self.ident += 2 + body = "{" + + if supercall is not None: + body += "\n" + " "*self.ident + supercall + stripped_fields = [field.strip() for field in field_res] + for var_name in field_names: + prefix = "\n" + self.ident * " " + body += prefix + "this." + var_name + " = " + var_name + body += "\n" + class_ident*" " + "}\n" + res += " "*class_ident + "constructor({f}) {b}".format( + f=", ".join(stripped_fields), + b=body, + ) + self.ident = class_ident + + if field_res: + res += "\n\n".join(field_res) + "\n" + if function_res: + res += "\n\n".join(function_res) + "\n" + + res += old_ident*" " + "\n}" + + self.ident = old_ident + self.is_interface = prev_is_interface + self.current_class = prev_class + self._children_res.append(res) + + @append_to + def visit_type_param(self, node): + res = node.name + if node.bound: + res += " extends " + self.get_type_name(node.bound) + self._children_res.append(res) + + @append_to + def visit_null_constant(self, node): + self._children_res.append(node.literal) + + @append_to + def visit_var_decl(self, node): + old_ident = self.ident + prefix = " " * self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + var_type = "const " if node.is_final else "let " + res = prefix + var_type + node.name + + if node.var_type is not None: + res += ": " + self.get_type_name(node.var_type) + + res += " = " + children_res[0] + self.ident = old_ident + self._children_res.append(res) + + @append_to + def visit_call_argument(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in node.children(): + c.accept(self) + self.ident = old_ident + children_res = self.pop_children_res(children) + res = children_res[0] + self._children_res.append(res) + + @append_to + def visit_field_decl(self, node): + prefix = self.ident * " " + res = prefix + node.name + ": " + self.get_type_name(node.field_type) + self._children_res.append(res) + + @append_to + def visit_param_decl(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + self.ident = old_ident + + param_type = node.param_type + res = node.name + ": " + self.get_type_name(param_type) + + res = "..." + res if node.vararg else res + + if len(children): + children_res = self.pop_children_res(children) + res += ( + "" if self.current_function.body is None + else " = " + children_res[0] + ) + + self._children_res.append(res) + + @append_to + @change_namespace + def visit_func_decl(self, node): + old_ident = self.ident + self.ident += 2 + prev_function = self.current_function + self.current_function = node + is_method = node.func_type == ast.FunctionDeclaration.CLASS_METHOD + is_interface = self.is_interface + self.is_interface = False + prev_is_void = self.is_void + self.is_void = node.get_type() == tst.VoidType() + + param_len = len(node.params) + + children = node.children() + for i, c in enumerate(children): + prev_namespace = self._namespace + if i < param_len: + self._namespace = self._namespace[:-1] + c.accept(self) + self._namespace = prev_namespace + children_res = self.pop_children_res(children) + + param_res = [children_res[i] for i, _ in enumerate(node.params)] + len_params = len(node.params) + + len_type_params = len(node.type_parameters) + type_parameters_res = ", ".join( + children_res[len_params:len_type_params + len_params]) + + body_res = children_res[-1] if node.body else '' + + prefix = " " * old_ident + arrow_func = "" + + if not is_method and not self.is_in_class(self._namespace): + prefix += "function " + elif not is_method: + prefix += f"let " + arrow_func = " = " + elif is_method and not node.body and not is_interface: + prefix += "abstract " + + type_params = ( + "<" + type_parameters_res + ">" if type_parameters_res else "") + + res = prefix + node.name + arrow_func + type_params + \ + "(" + ", ".join(param_res) + ")" + + if node.ret_type: + res += ": " + self.get_type_name(node.ret_type) + if body_res and arrow_func: + res += " => \n" + body_res + elif body_res and isinstance(node.body, ast.Block): + res += " \n" + body_res + elif body_res: + body_res = ("return " + body_res.strip() + if not self.is_void + else body_res.strip()) + res += "{\n" + " " * self.ident + \ + body_res + "\n" + " " * old_ident + "}" + + self.ident = old_ident + self.current_function = prev_function + self.is_void = prev_is_void + self.is_interface = is_interface + self._children_res.append(res) + + @append_to + @change_namespace + def visit_lambda(self, node): + + old_ident = self.ident + is_expression = not isinstance(node.body, ast.Block) + self.ident = 0 if is_expression else self.ident+2 + + children = node.children() + + prev_is_void = self.is_void + self.is_void = node.get_type() == tst.VoidType() + + prev_is_lambda = self.is_lambda + self.is_lambda = True + for c in children: + c.accept(self) + self.is_lambda = True + children_res = self.pop_children_res(children) + self.ident = old_ident + param_res = [children_res[i] for i, _ in enumerate(node.params)] + body_res = children_res[-1] if node.body else '' + + if not is_expression: + body_res = "{" + body_res + " "*old_ident + "}\n" + + res = "({params}): {ret} => {body}".format( + params=", ".join(param_res), + ret=self.get_type_name(node.ret_type), + body=body_res, + ) + + self.is_void = prev_is_void + self.is_lambda = prev_is_lambda + self._children_res.append(res) + + @append_to + def visit_bottom_constant(self, node): + bottom = "(undefined as unknown)" + + if node.t: + bottom = "(" + bottom + " as {})".format( + self.get_type_name(node.t) + ) + else: + bottom = "(undefined as never)" + + res = " "*self.ident + bottom + self._children_res.append(res) + + @append_to + def visit_integer_constant(self, node): + literal = "BigInt({})".format(str(node.literal)) \ + if isinstance(node.integer_type, tst.BigIntegerType) \ + else str(node.literal) + self._children_res.append(" "*self.ident + literal) + + @append_to + def visit_real_constant(self, node): + literal = str(node.literal) + self._children_res.append(" "*self.ident + literal) + + @append_to + def visit_char_constant(self, node): + # Symbol type in TypeScript + self._children_res.append(f"'{node.literal}'") + + @append_to + def visit_string_constant(self, node): + self._children_res.append('"{}"'.format(node.literal)) + + @append_to + def visit_boolean_constant(self, node): + self._children_res.append(str(node.literal)) + + @append_to + def visit_array_expr(self, node): + if not node.length: + return self._children_res.append("[]") + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + self.ident = old_ident + return self._children_res.append("[{}]".format( + ", ".join(children_res))) + + @append_to + def visit_variable(self, node): + res = node.name + decl = get_decl(self.context, self._namespace, node.name) + assert decl is not None + _, decl = decl + if isinstance(decl, ast.FieldDeclaration): + res = "this." + res + self._children_res.append(" " * self.ident + res) + + @append_to + def visit_binary_op(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + res = "{}(({}) {} ({}))".format( + " "*old_ident, children_res[0], node.operator, + children_res[1]) + self.ident = old_ident + self._children_res.append(res) + + def visit_logical_expr(self, node): + self.visit_binary_op(node) + + def visit_equality_expr(self, node): + self.visit_binary_op(node) + + def visit_comparison_expr(self, node): + self.visit_binary_op(node) + + def visit_arith_expr(self, node): + self.visit_binary_op(node) + + @append_to + def visit_conditional(self, node): + old_ident = self.ident + self.ident += 2 + prev_namespace = self._namespace + children = node.children() + + cond = children[0] + cond.accept(self) + true_branch = children[1] + false_branch = children[2] + + if isinstance(cond, ast.Is): + self._namespace = prev_namespace + ('true_block',) + true_branch.accept(self) + self._namespace = prev_namespace + ('false_block',) + false_branch.accept(self) + else: + true_branch.accept(self) + false_branch.accept(self) + self._namespace = prev_namespace + + children_res = self.pop_children_res(children) + res = "({}{} ? {} : {})".format( + " "*old_ident, children_res[0].strip(), + children_res[1].strip(), + children_res[2].strip()) + + self.ident = old_ident + self._children_res.append(res) + + @append_to + def visit_is(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + res = "{}{} {} {}".format( + " " * old_ident, children_res[0], str(node.operator), + node.rexpr.name) + self.ident = old_ident + self._children_res.append(res) + + @append_to + def visit_new(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + self.ident = old_ident + # Remove type arguments from Parameterized Type + if getattr(node.class_type, 'can_infer_type_args', None) is True: + self._children_res.append("new {}({})".format( + " " * self.ident + node.class_type.name, + ", ".join(children_res))) + else: + self._children_res.append("new {}({})".format( + " " * self.ident + self.get_type_name(node.class_type), + ", ".join(children_res))) + + @append_to + def visit_field_access(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + children_res = self.pop_children_res(children) + self.ident = old_ident + receiver_expr = (children_res[0] if children_res[0] + else "this") + res = "{}.{}".format(receiver_expr, node.field) + self._children_res.append(res) + + @append_to + def visit_func_ref(self, node): + old_ident = self.ident + + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + + self.ident = old_ident + children_res = self.pop_children_res(children) + + this_prefix = children_res[0] if children_res else "" + + decl = get_decl(self.context, self._namespace, node.func) + + if decl is not None: + _, decl = decl + + if self.needs_this_prefix(node, decl): + this_prefix += "this" + # TODO Must check signatures and not names + # (for overwritten + overloaded functions) + + res = "{}{}{}".format( + " "*self.ident, + "({}).".format(this_prefix) if this_prefix else "", + node.func + ) + + self._children_res.append(res) + + @append_to + def visit_func_call(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + self.ident = old_ident + children_res = self.pop_children_res(children) + type_args = "" + if not node.can_infer_type_args and node.type_args: + type_args += ( + "<" + ",".join( + [self.get_type_name(t) for t in node.type_args]) + ">" + ) + + this_prefix = "" + + decl = get_decl(self.context, self._namespace, node.name) + + if decl is not None: + _, decl = decl + + if self.needs_this_prefix(node, decl): + this_prefix += "this." + # FIXME Must check signatures and not names + # (for overwritten + overloaded functions) + + if node.receiver: + receiver_expr = children_res[0] + res = "{}{}.{}{}({})".format( + " " * self.ident, receiver_expr, node.func, + type_args, + ", ".join(children_res[1:])) + else: + res = "{}{}{}{}({})".format( + " " * self.ident, + this_prefix, node.func, type_args, + ", ".join(children_res)) + self._children_res.append(res) + + @append_to + def visit_assign(self, node): + old_ident = self.ident + self.ident = 0 + children = node.children() + for c in children: + c.accept(self) + self.ident = old_ident + children_res = self.pop_children_res(children) + + decl = get_decl(self.context, self._namespace, node.name) + + if decl is not None: + _, decl = decl + + if node.receiver or decl is None or isinstance( + decl, ast.FieldDeclaration): + receiver_expr = (children_res[0] if node.receiver else "this") + expr = children_res[1] if node.receiver else children_res[0] + res = "{}{}.{} = {}".format( + " " * old_ident, + receiver_expr, + node.name, + expr + ) + else: + res = "{}{} = {}".format( + " " * old_ident, + node.name, + children_res[0] + ) + + self.ident = old_ident + self._children_res.append(res) + + @append_to + def visit_type_alias_decl(self, node): + old_ident = self.ident + prefix = " " * self.ident + self.ident = 0 + res = prefix + "type " + node.name + res += " = " + self.get_type_name(node.alias) + self.ident = old_ident + self._children_res.append(res) diff --git a/src/translators/utils.py b/src/translators/utils.py new file mode 100644 index 00000000..f7c9838c --- /dev/null +++ b/src/translators/utils.py @@ -0,0 +1,6 @@ +def append_to(visit): + def inner(self, node): + self._nodes_stack.append(node) + visit(self, node) + self._nodes_stack.pop() + return inner diff --git a/src/utils.py b/src/utils.py index bb5f4573..3b7a4ddb 100644 --- a/src/utils.py +++ b/src/utils.py @@ -167,6 +167,28 @@ def caps(self, length=1, blacklist=None): def range(self, from_value, to_value): return range(0, self.integer(from_value, to_value)) + def identifier(self, ident_type:str=None) -> str: + """Generate an identifier name. + + Args: + ident_type: None or 'capitalize' or 'lower' + + Raises: + AssertionError: Raises an AssertionError if the ident_type is neither + 'capitalize' nor 'lower'. + """ + word = self.word() + if ident_type is None: + return word + if ident_type == 'lower': + return word.lower() + if ident_type == 'capitalize': + return word.capitalize() + raise AssertionError("ident_type should be 'capitalize' or 'lower'") + + def shuffle(self, ll): + return self.r.shuffle(ll) + random = RandomUtils() diff --git a/tests/resources/program1.py b/tests/resources/program1.py index 475a1ac7..b9a2b99e 100644 --- a/tests/resources/program1.py +++ b/tests/resources/program1.py @@ -118,4 +118,4 @@ ctx.add_func(GLOBAL_NAMESPACE + ('A',), foo_func.name, foo_func) ctx.add_func(GLOBAL_NAMESPACE + ('A',), buz_func.name, buz_func) ctx.add_func(GLOBAL_NAMESPACE + ('A',), spam_func.name, spam_func) -program = Program(ctx, language="kotlin") +program = Program(ctx, "kotlin", KotlinBuiltinFactory()) diff --git a/tests/resources/program10.py b/tests/resources/program10.py index 591eea19..789541b9 100644 --- a/tests/resources/program10.py +++ b/tests/resources/program10.py @@ -126,4 +126,4 @@ ctx.add_var(ast.GLOBAL_NAMESPACE + ("Third",), third_z.name, third_z) ctx.add_func(ast.GLOBAL_NAMESPACE + ("Third",), third_foo.name, third_foo) ctx.add_var(ast.GLOBAL_NAMESPACE + ("Third", "foo"), third_foo_k.name, third_foo_k) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program11.py b/tests/resources/program11.py index c5c97478..2a39a327 100644 --- a/tests/resources/program11.py +++ b/tests/resources/program11.py @@ -87,4 +87,4 @@ ctx.add_var(ast.GLOBAL_NAMESPACE + ("foo",), foo_x.name, foo_x) ctx.add_func(ast.GLOBAL_NAMESPACE, "bar", fun_bar) ctx.add_var(ast.GLOBAL_NAMESPACE + ("bar",), bar_y.name, bar_y) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program12.py b/tests/resources/program12.py index 18a3faba..7f6a98b0 100644 --- a/tests/resources/program12.py +++ b/tests/resources/program12.py @@ -87,4 +87,4 @@ ctx.add_func(ast.GLOBAL_NAMESPACE, "foo", fun_foo) ctx.add_var(ast.GLOBAL_NAMESPACE + ("foo",), foo_x.name, foo_x) ctx.add_func(ast.GLOBAL_NAMESPACE, "bar", fun_bar) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program2.py b/tests/resources/program2.py index fb01532e..fb19e916 100644 --- a/tests/resources/program2.py +++ b/tests/resources/program2.py @@ -44,4 +44,4 @@ ctx.add_var(GLOBAL_NAMESPACE + ('Bam',), 'x', xB_field) ctx.add_var(GLOBAL_NAMESPACE + ('Bam', 'getX'), 'z', z_get_param) ctx.add_func(GLOBAL_NAMESPACE + ('Bam',), getX_func.name, getX_func) -program = Program(ctx, language="kotlin") +program = Program(ctx, "kotlin", KotlinBuiltinFactory()) diff --git a/tests/resources/program3.py b/tests/resources/program3.py index 3f6728c4..25b26fb8 100644 --- a/tests/resources/program3.py +++ b/tests/resources/program3.py @@ -73,4 +73,4 @@ ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "bar"), bar_x.name, bar_x) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "bar"), bar_z.name, bar_z) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "bar"), bar_y.name, bar_y) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program4.py b/tests/resources/program4.py index 672bc09d..a4e97f4e 100644 --- a/tests/resources/program4.py +++ b/tests/resources/program4.py @@ -38,4 +38,4 @@ ctx.add_func(ast.GLOBAL_NAMESPACE + ("A",), fun_foo.name, fun_foo) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A",), field_x.name, field_x) ctx.add_func(ast.GLOBAL_NAMESPACE + ("A", "foo"), fun_bar.name, fun_bar) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program5.py b/tests/resources/program5.py index 0237127b..b461fa76 100644 --- a/tests/resources/program5.py +++ b/tests/resources/program5.py @@ -82,4 +82,4 @@ ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "bar"), bar_z.name, bar_z) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "quz"), quz_y.name, quz_y) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program6.py b/tests/resources/program6.py index 2a171fb5..1bfba86a 100644 --- a/tests/resources/program6.py +++ b/tests/resources/program6.py @@ -59,4 +59,4 @@ ctx.add_func(ast.GLOBAL_NAMESPACE + ("A",), fun_baz.name, fun_baz) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A",), field_x.name, field_x) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "bar"), bar_y.name, bar_y) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program7.py b/tests/resources/program7.py index e300fb6d..28c61b2d 100644 --- a/tests/resources/program7.py +++ b/tests/resources/program7.py @@ -42,4 +42,4 @@ ctx.add_func(ast.GLOBAL_NAMESPACE + ("A",), fun_bar.name, fun_bar) ctx.add_func(ast.GLOBAL_NAMESPACE + ("A",), fun_foo.name, fun_foo) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "foo",), foo_x.name, foo_x) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program8.py b/tests/resources/program8.py index 45dea83d..83de0a3b 100644 --- a/tests/resources/program8.py +++ b/tests/resources/program8.py @@ -33,4 +33,4 @@ ctx.add_class(ast.GLOBAL_NAMESPACE, "A", cls) ctx.add_func(ast.GLOBAL_NAMESPACE + ("A",), fun_foo.name, fun_foo) ctx.add_var(ast.GLOBAL_NAMESPACE + ("A", "foo",), foo_x.name, foo_x) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/program9.py b/tests/resources/program9.py index 4ed2eb8c..547f39b9 100644 --- a/tests/resources/program9.py +++ b/tests/resources/program9.py @@ -63,4 +63,4 @@ ctx.add_var(ast.GLOBAL_NAMESPACE + ("foo",), foo_lank.name, foo_lank) ctx.add_func(ast.GLOBAL_NAMESPACE, "bar", fun_bar) ctx.add_var(ast.GLOBAL_NAMESPACE + ("bar",), bar_cinches.name, bar_cinches) -program = ast.Program(ctx, language="kotlin") +program = ast.Program(ctx, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/resources/type_analysis_programs.py b/tests/resources/type_analysis_programs.py index 65df429b..31673be1 100644 --- a/tests/resources/type_analysis_programs.py +++ b/tests/resources/type_analysis_programs.py @@ -12,7 +12,7 @@ context = ctx.Context() context.add_var(ast.GLOBAL_NAMESPACE, var_decl.name, var_decl) context.add_class(ast.GLOBAL_NAMESPACE, cls.name, cls) -program1 = ast.Program(context, "kotlin") +program1 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program2 @@ -28,7 +28,7 @@ context.add_var(ast.GLOBAL_NAMESPACE, var_x.name, var_x) context.add_var(ast.GLOBAL_NAMESPACE, var_y.name, var_y) context.add_class(ast.GLOBAL_NAMESPACE, cls.name, cls) -program2 = ast.Program(context, "kotlin") +program2 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program3 @@ -41,7 +41,7 @@ context.add_var(ast.GLOBAL_NAMESPACE, var_y.name, var_y2) context.add_class(ast.GLOBAL_NAMESPACE, cls.name, cls) context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) -program3 = ast.Program(context, "kotlin") +program3 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program4 @@ -62,7 +62,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) context.add_class(ast.GLOBAL_NAMESPACE, cls3.name, cls3) -program4 = ast.Program(context, "kotlin") +program4 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program5 @@ -85,7 +85,7 @@ context.add_var(ast.GLOBAL_NAMESPACE, var2.name, var2) context.add_var(FUNC_NAMESPACE, param1.name, param1) context.add_var(FUNC_NAMESPACE, var1.name, var1) -program5 = ast.Program(context, "kotlin") +program5 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program6 @@ -105,7 +105,7 @@ context.add_var(ast.GLOBAL_NAMESPACE, var.name, var) context.add_func(ast.GLOBAL_NAMESPACE + (cls1.name,), func.name, func) context.add_var(FUNC_NAMESPACE, param1.name, param1) -program6 = ast.Program(context, "kotlin") +program6 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program7 @@ -119,7 +119,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) context.add_var(ast.GLOBAL_NAMESPACE, var2.name, var2) -program7 = ast.Program(context, "kotlin") +program7 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 8 @@ -135,7 +135,7 @@ context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_var(FUNC_NAMESPACE, var1.name, var1) -program8 = ast.Program(context, "kotlin") +program8 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program9 @@ -158,7 +158,7 @@ context.add_var(ast.GLOBAL_NAMESPACE + (cls2.name,), f.name, f) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(FUNC_NAMESPACE, var1.name, var1) -program9 = ast.Program(context, "kotlin") +program9 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program10 @@ -176,7 +176,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_var(ast.GLOBAL_NAMESPACE + (cls1.name,), f.name, f) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) -program10 = ast.Program(context, "kotlin") +program10 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 11 @@ -191,7 +191,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_var(ast.GLOBAL_NAMESPACE + (cls1.name,), f.name, f) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) -program11 = ast.Program(context, "kotlin") +program11 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 12 @@ -217,7 +217,7 @@ context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(ast.GLOBAL_NAMESPACE + (cls1.name,), f.name, f) context.add_var(FUNC_NAMESPACE, var1.name, var1) -program12 = ast.Program(context, "kotlin") +program12 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 13 @@ -251,7 +251,7 @@ context.add_var(FUNC_NAMESPACE, var1.name, var1) context.add_var(FUNC_NAMESPACE, var2.name, var2) context.add_var(FUNC_NAMESPACE, var3.name, var3) -program13 = ast.Program(context, "kotlin") +program13 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 14 @@ -265,7 +265,7 @@ context = ctx.Context() context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(FUNC_NAMESPACE, var1.name, var1) -program14 = ast.Program(context, "kotlin") +program14 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program15 @@ -280,7 +280,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_var(ast.GLOBAL_NAMESPACE + (cls1.name,), f.name, f) context.add_var(ast.GLOBAL_NAMESPACE, var.name, var) -program15 = ast.Program(context, "kotlin") +program15 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program16 @@ -298,7 +298,7 @@ context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_func(ast.GLOBAL_NAMESPACE, func2.name, func2) context.add_var(ast.GLOBAL_NAMESPACE + (func.name,), param1.name, param1) -program16 = ast.Program(context, "kotlin") +program16 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 17 @@ -315,7 +315,7 @@ context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_func(ast.GLOBAL_NAMESPACE, func2.name, func2) context.add_var(ast.GLOBAL_NAMESPACE + (func.name,), param1.name, param1) -program17 = ast.Program(context, "kotlin") +program17 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 18 @@ -336,7 +336,7 @@ context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_func(ast.GLOBAL_NAMESPACE, func2.name, func2) context.add_var(ast.GLOBAL_NAMESPACE + (func.name,), param1.name, param1) -program18 = ast.Program(context, "kotlin") +program18 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 19 @@ -359,7 +359,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls3.name, cls3) context.add_var(ast.GLOBAL_NAMESPACE + (cls3.name,), f.name, f) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program19 = ast.Program(context, "kotlin") +program19 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 20 @@ -377,7 +377,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) context.add_var(ast.GLOBAL_NAMESPACE + (cls2.name,), cls2.name, cls2) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program20 = ast.Program(context, "kotlin") +program20 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 21 @@ -396,7 +396,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls3.name, cls3) context.add_var(ast.GLOBAL_NAMESPACE + (cls3.name,), cls3.name, cls3) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program21 = ast.Program(context, "kotlin") +program21 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 22 @@ -415,7 +415,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) context.add_var(ast.GLOBAL_NAMESPACE + (cls2.name,), f.name, f) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program22 = ast.Program(context, "kotlin") +program22 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 23 @@ -429,7 +429,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program23 = ast.Program(context, "kotlin") +program23 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 24 @@ -444,7 +444,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program24 = ast.Program(context, "kotlin") +program24 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 25 @@ -459,7 +459,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls1.name, cls1) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program25 = ast.Program(context, "kotlin") +program25 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 26 @@ -481,7 +481,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program26 = ast.Program(context, "kotlin") +program26 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 27 @@ -505,7 +505,7 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) context.add_func(ast.GLOBAL_NAMESPACE, func.name, func) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program27 = ast.Program(context, "kotlin") +program27 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 28 @@ -531,7 +531,7 @@ context.add_var(ast.GLOBAL_NAMESPACE + (cls2.name,), f.name, f) context.add_func(ast.GLOBAL_NAMESPACE + (cls3.name,), func.name, func) context.add_var(ast.GLOBAL_NAMESPACE + (cls3.name, func.name), var1.name, var1) -program28 = ast.Program(context, "kotlin") +program28 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) # program 29 @@ -549,4 +549,4 @@ context.add_class(ast.GLOBAL_NAMESPACE, cls2.name, cls2) context.add_var(ast.GLOBAL_NAMESPACE + (cls2.name,), f.name, f) context.add_var(ast.GLOBAL_NAMESPACE, var1.name, var1) -program29 = ast.Program(context, "kotlin") +program29 = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) diff --git a/tests/test_type_utils.py b/tests/test_type_utils.py index b3a53593..552225d5 100644 --- a/tests/test_type_utils.py +++ b/tests/test_type_utils.py @@ -842,7 +842,7 @@ def test_type_hint_field_access_inheritance(): expr = ast.FieldAccess(cond3, "f") - program = ast.Program(context, language="kotlin") + program = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) types = program.get_types() assert tutils.get_type_hint(expr, context, ast.GLOBAL_NAMESPACE, KT_FACTORY, types) == kt.Integer @@ -875,7 +875,7 @@ def test_type_hint_smart_cast(): cond2 = ast.Conditional(ast.Is(expr2, cls.get_type()), expr2, expr3, cls.get_type()) smart_casts = [(expr, cls.get_type())] - program = ast.Program(context, language="kotlin") + program = ast.Program(context, "kotlin", kt.KotlinBuiltinFactory()) types = program.get_types() assert tutils.get_type_hint(expr, context, ast.GLOBAL_NAMESPACE, diff --git a/tests/test_typescript.py b/tests/test_typescript.py new file mode 100644 index 00000000..a1889d84 --- /dev/null +++ b/tests/test_typescript.py @@ -0,0 +1,258 @@ +from src.ir.builtins import NumberType +import src.ir.typescript_types as tst +import src.ir.typescript_ast as ts_ast +import src.ir.types as tp +import src.ir.type_utils as tu + + +def test_type_alias_with_literals(): + # Tests subtyping relations between a string alias and string literal + # and between a number alias and a number literal. + # - (type Foo = string) with literal "foo" + # - (type Bar = number) with literal 5 + string_alias = ts_ast.TypeAliasDeclaration("Foo", tst.StringType()).get_type() + number_alias = ts_ast.TypeAliasDeclaration("Bar", tst.NumberType()).get_type() + + string_lit = tst.StringLiteralType("foo") + number_lit = tst.NumberLiteralType(5) + + assert string_lit.is_subtype(string_alias) + assert not string_alias.is_subtype(string_lit) + assert number_lit.is_subtype(number_alias) + assert not number_alias.is_subtype(number_lit) + + +def test_type_alias_with_literals2(): + # Tests subtyping relation between a literal alias + # and their corresponding literal type. + # - (type Foo = "foo") with literal "foo" + # - (type Bar = "bar") with literal "bar" + string_alias = ts_ast.TypeAliasDeclaration("Foo", tst.StringLiteralType("foo")).get_type() + number_alias = ts_ast.TypeAliasDeclaration("Bar", tst.NumberLiteralType(5)).get_type() + + string_lit = tst.StringLiteralType("foo") + number_lit = tst.NumberLiteralType(5) + + assert string_lit.is_subtype(string_alias) + assert number_lit.is_subtype(number_alias) + assert string_alias.is_subtype(string_lit) + assert number_alias.is_subtype(number_lit) + + +def test_union_types_simple(): + # Tests subtyping relation between union types + # and the types in their union. + # - number | boolean + # - boolean | "bar" + # - boolean | number + union_1 = tst.UnionType([tst.NumberType(), tst.BooleanType()]) + + bar_lit = tst.StringLiteralType("bar") + union_2 = tst.UnionType([tst.BooleanType(), bar_lit]) + + union_3 = tst.UnionType([tst.BooleanType(), tst.NumberType()]) + + assert not union_1.is_subtype(union_2) + assert not union_2.is_subtype(union_1) + assert union_3.is_subtype(union_1) + assert union_1.is_subtype(union_3) + + +def test_union_types_other_types(): + # Tests that types A, B are subtypes of A | B + union = tst.UnionType([tst.NumberType(), tst.BooleanType()]) + assert tst.NumberType().is_subtype(union) + assert tst.BooleanType().is_subtype(union) + + +def test_union_type_assign(): + # Tests correct creation and assignment of union type + union = tst.UnionType([tst.StringType(), tst.NumberType(), tst.BooleanType(), tst.ObjectType()]) + foo = tst.StringType() + + assert len(union.types) == 4 + assert not union.is_subtype(foo) + assert foo.is_subtype(union) + + +def test_union_type_param(): + # Tests that union type bounds of type parameters do not + # conflict with the sybtyping relations between the two. + union1 = tst.UnionType([tst.NumberType(), tst.NullType()]) + union2 = tst.UnionType([tst.StringLiteralType("foo"), tst.NumberType()]) + t_param = tp.TypeParameter("T", bound=union2) + + assert not union2.is_subtype(union1) + assert not union1.is_subtype(t_param) + assert not t_param.is_subtype(union1) + + +def test_union_type_substitution(): + # Tests substitution of type parametes in union types + type_param1 = tp.TypeParameter("T1") + type_param2 = tp.TypeParameter("T2") + type_param3 = tp.TypeParameter("T3") + type_param4 = tp.TypeParameter("T4") + + foo = tp.TypeConstructor("Foo", [type_param1, type_param2]) + foo_p = foo.new([tst.NumberType(), type_param3]) + + union = tst.UnionType([tst.StringLiteralType("bar"), foo_p]) + ptype = tp.substitute_type(union, {type_param3: type_param4}) + + assert ptype.types[1].type_args[0] == tst.NumberType() + assert ptype.types[1].type_args[1] == type_param4 + + +def test_union_type_substitution_type_var_bound(): + # Tests substitution of bounded type parameters in union types + type_param1 = tp.TypeParameter("T1") + type_param2 = tp.TypeParameter("T2", bound=type_param1) + type_map = {type_param1: tst.StringType()} + + union = tst.UnionType([tst.NumberType(), type_param2]) + ptype_union = tp.substitute_type(union, type_map) + ptype = ptype_union.types[1] + + + assert ptype.name == type_param2.name + assert ptype.variance == type_param2.variance + assert ptype.bound == tst.StringType() + + +def test_union_to_type_variable_free(): + # Tests the builtin method to-type-variable-free of union types + type_param1 = tp.TypeParameter("T1") + type_param2 = tp.TypeParameter("T2") + foo = tp.TypeConstructor("Foo", [type_param1]) + foo_t = foo.new([type_param2]) + union = tst.UnionType([foo_t, tst.StringLiteralType("bar")]) + + union_n = union.to_type_variable_free(tst.TypeScriptBuiltinFactory()) + foo_n = union_n.types[0] + assert foo_n.type_args[0] == tp.WildCardType(tst.ObjectType(), variance=tp.Covariant) + + type_param2.bound = tst.NumberType() + foo_t = foo.new([type_param2]) + union = tst.UnionType([foo_t, tst.NumberLiteralType(43)]) + + union_n = union.to_type_variable_free(tst.TypeScriptBuiltinFactory()) + foo_n = union_n.types[0] + assert foo_n.type_args[0] == tp.WildCardType(tst.NumberType(), variance=tp.Covariant) + + bar = tp.TypeConstructor("Bar", [tp.TypeParameter("T")]) + bar_p = bar.new([type_param2]) + foo_t = foo.new([bar_p]) + union = tst.UnionType([foo_t, tst.NumberType(), tst.StringType(), tst.AliasType(tst.StringLiteralType("foobar"))]) + + union_n = union.to_type_variable_free(tst.TypeScriptBuiltinFactory()) + foo_n = union_n.types[0] + assert foo_n.type_args[0] == bar.new( + [tp.WildCardType(tst.NumberType(), variance=tp.Covariant)]) + + +def test_union_type_unification_type_var(): + union = tst.UnionType([tst.StringType(), tst.StringLiteralType("foo")]) + type_param = tp.TypeParameter("T") + + # Case 1: Unify a union with an unbounded type param + type_var_map = tu.unify_types(union, type_param, tst.TypeScriptBuiltinFactory()) + assert len(type_var_map) == 1 + assert type_var_map == {type_param: union} + + # Case 2: unify a union with a bounded type param, which has an + # incompatible bound with the given union. + union = tst.UnionType([tst.NumberType(), tst.StringType()]) + type_param = tp.TypeParameter("T", bound=tst.NumberType()) + + type_var_map = tu.unify_types(union, type_param, + tst.TypeScriptBuiltinFactory()) + assert type_var_map == {} + + + # Case 3: unify a union with a bounded type param, which has a compatible + # bound with the given union. + type_param = tp.TypeParameter("T", bound=union) + type_var_map = tu.unify_types(union, type_param, + tst.TypeScriptBuiltinFactory()) + assert type_var_map == {type_param: union} + +def test_union_type_unification(): + type_param = tp.TypeParameter("T") + union1 = tst.UnionType([tst.NumberLiteralType(1410), tst.NumberType(), tst.StringType()]) + union2 = tst.UnionType([type_param, tst.NumberType(), tst.StringType()]) + assert union1.is_subtype(union2) + + # Unify t1: 1410 | number | string + # with t2: T | number | string + # Result should be: {T: 1410} + type_var_map = tu.unify_types(union1, union2, tst.TypeScriptBuiltinFactory()) + assert len(type_var_map) == 1 + assert type_var_map == {type_param: union1.types[0]} + + type_param2 = tp.TypeParameter("G") + union3 = tst.UnionType([type_param, type_param2, tst.StringLiteralType("foo")]) + + # Unify t1: 1410 | number | string + # with t3: T | G | "foo". + # Result should be: {T: number, G: string} or reversed. + type_var_map = tu.unify_types(union1, union3, tst.TypeScriptBuiltinFactory()) + assert len(type_var_map) == 2 + assert type_param, type_param2 in type_var_map + assert union1.types[1], union1.types[2] in type_var_map.values() + + +def test_union_type_unification2(): + union = tst.UnionType([tst.NumberType(), tst.StringType()]) + assert tu.unify_types(tst.BooleanType(), union, tst.TypeScriptBuiltinFactory()) == {} + + # Unify t1: number + # with t2: number | T + # Result should be: {T: number} + t1 = tst.NumberType() + t2 = tst.UnionType([tst.NumberType(), tp.TypeParameter("T")]) + res = tu.unify_types(t1, t2, tst.TypeScriptBuiltinFactory()) + assert len(res) == 1 and res[t2.types[1]] == t1 + + # Unify t1: number | string + # with t2: number | T + # Result should be: {T: string} + t1 = tst.UnionType([tst.NumberType(), tst.StringType()]) + res = tu.unify_types(t1, t2, tst.TypeScriptBuiltinFactory()) + assert len(res) == 1 and res[t2.types[1]] == t1.types[1] + + # Unify t1: number | "foo" | string + # with t2: number | T + # Result should be: {T: string} + t1 = tst.UnionType([tst.NumberType(), tst.StringLiteralType("foo"), tst.StringType()]) + res = tu.unify_types(t1, t2, tst.TypeScriptBuiltinFactory()) + assert len(res) == 1 and res[t2.types[1]] == t1.types[2] + + t1 = tst.UnionType([tst.NumberType(), tst.NumberLiteralType(100), tst.BooleanType(), tst.StringLiteralType("foo"), tst.StringType()]) + + t_param1 = tp.TypeParameter("T", bound=tst.StringType()) + helper_union = tst.UnionType([tst.BooleanType(), tst.StringType()]) + t_param2 = tp.TypeParameter("G", bound=helper_union) + t2 = tst.UnionType([tst.NumberType(), t_param1, t_param2]) + + # Unify t1: number | 100 | boolean | "foo" | string + # with t2: number | T extends string | G extends (boolean | string) + # Result should be {T: StringType, G: BooleanType} + res = tu.unify_types(t1, t2, tst.TypeScriptBuiltinFactory()) + assert (len(res) == 2 and + res[t2.types[1]] == t1.types[4] and + res[t2.types[2]] == t1.types[2]) + + +def test_union_to_type_variable_free(): + type_param = tp.TypeParameter("S") + union = tst.UnionType([tst.NumberType(), type_param]) + + new_union = union.to_type_variable_free(tst.TypeScriptBuiltinFactory()) + assert new_union == tst.UnionType([tst.NumberType(), + tst.TypeScriptBuiltinFactory().get_any_type()]) + + type_param = tp.TypeParameter("S", bound=tst.StringType()) + union = tst.UnionType([tst.NumberType(), type_param]) + new_union = union.to_type_variable_free(tst.TypeScriptBuiltinFactory()) + assert new_union == tst.UnionType([tst.NumberType(), tst.StringType()])