diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 57de4efc1..aae1135ca 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -19,7 +19,7 @@ jobs: - name: Build artifacts run: | - pipx run poetry build + poetry build - name: Archive artifacts uses: actions/upload-artifact@v3 @@ -57,6 +57,7 @@ jobs: pip install --no-input $artifact python -c 'import lava.magma.compiler.subcompilers' python -c 'import lava.magma.core.model' + python -c 'from lava.magma.runtime.message_infrastructure.multiprocessing import MultiProcessing' pip uninstall -y lava-nc deactivate rm -rf artifact-test @@ -69,11 +70,11 @@ jobs: pip install --no-input $artifact python -c 'import lava.magma.compiler.subcompilers' python -c 'import lava.magma.core.model' + python -c 'from lava.magma.runtime.message_infrastructure.multiprocessing import MultiProcessing' pip uninstall -y lava-nc deactivate rm -rf artifact-test - test-artifact-use: name: Test Artifact With Unit Tests runs-on: ubuntu-latest diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0374b531..1650c8157 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,12 +22,26 @@ jobs: - name: setup CI uses: lava-nc/ci-setup-composite-action@v1.2 with: - repository: 'Lava' + repository: 'Lava' - name: Run flakeheaven (flake8) run: | poetry run flakeheaven lint src/lava tests/ poetry run find tutorials/ -name '*.py' -exec flakeheaven lint {} \+ + + - name: Run cpplint (cpplint) + run: | + poetry run cpplint --recursive --quiet \ + --repository=src/lava/magma/runtime/_c_message_infrastructure \ + --root=csrc \ + --exclude=src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos \ + src/lava/magma/runtime/_c_message_infrastructure/ \ + src/lava/magma/runtime/_c_message_infrastructure/test/ + + poetry run cpplint --recursive --quiet \ + --repository=src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp \ + --root=src \ + src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src security-lint: name: Security Lint Code @@ -47,7 +61,7 @@ jobs: with: targets: | src/lava/. - options: "-r --format custom --msg-template '{abspath}:{line}: {test_id}[bandit]: {severity}: {msg}'" + options: "-r --exclude build --format custom --msg-template '{abspath}:{line}: {test_id}[bandit]: {severity}: {msg}'" unit-tests: name: Unit Test Code + Coverage @@ -82,3 +96,22 @@ jobs: name: coverage path: coverage.xml retention-days: 30 + + msg-infr-unit-tests: + name: Message Infrastructure Unit Test + runs-on: ubuntu-latest + env: + DEBUG: 1 + + steps: + - uses: actions/checkout@v3 + with: + lfs: true + + - name: setup CI + uses: lava-nc/ci-setup-composite-action@v1.2 + with: + repository: 'Lava' + + - name: Run message infrastructure cpp unit tests + run: ./build/test/test_messaging_infrastructure diff --git a/.gitignore b/.gitignore index cd5d1d84b..f7dc204e5 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ src/lava/utils/dataloader/mnist.npy # Distribution / packaging .Python build/ +log/ develop-eggs/ dist/ downloads/ diff --git a/poetry.lock b/poetry.lock index 40a7041ff..e32707974 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,6 +4,7 @@ name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -15,6 +16,7 @@ files = [ name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "dev" optional = false python-versions = "*" files = [ @@ -26,6 +28,7 @@ files = [ name = "argparse" version = "1.4.0" description = "Python command-line parsing library" +category = "dev" optional = false python-versions = "*" files = [ @@ -37,6 +40,7 @@ files = [ name = "asteval" version = "0.9.31" description = "Safe, minimalistic evaluator of python expression using ast module" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -54,6 +58,7 @@ test = ["coverage", "pytest", "pytest-cov"] name = "asttokens" version = "2.2.1" description = "Annotate AST trees with source code positions" +category = "dev" optional = false python-versions = "*" files = [ @@ -71,6 +76,7 @@ test = ["astroid", "pytest"] name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -89,6 +95,7 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "autopep8" version = "1.6.0" description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +category = "dev" optional = false python-versions = "*" files = [ @@ -104,6 +111,7 @@ toml = "*" name = "babel" version = "2.12.1" description = "Internationalization utilities" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -118,6 +126,7 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "dev" optional = false python-versions = "*" files = [ @@ -129,6 +138,7 @@ files = [ name = "bandit" version = "1.7.4" description = "Security oriented static analyser for python code." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -151,6 +161,7 @@ yaml = ["PyYAML"] name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" +category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -169,6 +180,7 @@ lxml = ["lxml"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -187,6 +199,7 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "build" version = "0.10.0" description = "A simple, correct Python build frontend" +category = "dev" optional = false python-versions = ">= 3.7" files = [ @@ -210,6 +223,7 @@ virtualenv = ["virtualenv (>=20.0.35)"] name = "cachecontrol" version = "0.12.14" description = "httplib2 caching for requests" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -230,6 +244,7 @@ redis = ["redis (>=2.10.5)"] name = "certifi" version = "2023.7.22" description = "Python package for providing Mozilla's CA Bundle." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -241,6 +256,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "dev" optional = false python-versions = "*" files = [ @@ -317,6 +333,7 @@ pycparser = "*" name = "charset-normalizer" version = "3.2.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -401,6 +418,7 @@ files = [ name = "cleo" version = "2.0.1" description = "Cleo allows you to create beautiful and testable command-line interfaces." +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -416,6 +434,7 @@ rapidfuzz = ">=2.2.0,<3.0.0" name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -427,6 +446,7 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -446,6 +466,7 @@ typing = ["mypy (>=0.990)"] name = "contourpy" version = "1.1.0" description = "Python library for calculating contours of 2D quadrilateral grids" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -508,6 +529,7 @@ test-no-images = ["pytest", "pytest-cov", "wurlitzer"] name = "coverage" version = "6.5.0" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -569,10 +591,27 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "cpplint" +version = "1.6.1" +description = "Automated checker to ensure C++ files follow Google's style guide" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "cpplint-1.6.1-py3-none-any.whl", hash = "sha256:00ddc86d6e4de2a9dcfa272402dcbe21593363a93b7c475bc391e335062f34b1"}, + {file = "cpplint-1.6.1.tar.gz", hash = "sha256:d430ce8f67afc1839340e60daa89e90de08b874bc27149833077bba726dfc13a"}, +] + +[package.extras] +dev = ["configparser (<=3.7.4)", "flake8 (>=4.0.1)", "flake8-polyfill", "importlib-metadata (>=0.12)", "pylint (>=2.11.0)", "pyparsing (<3)", "pytest (>=4.6,<5.0)", "pytest-cov", "testfixtures", "tox (>=3.0.0)", "tox-pyenv", "zipp (<=0.5.1)"] +test = ["configparser (<=3.7.4)", "pyparsing (<3)", "pytest (>=4.6,<5.0)", "pytest-cov", "testfixtures", "zipp (<=0.5.1)"] + [[package]] name = "crashtest" version = "0.4.1" description = "Manage Python errors with ease" +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -584,6 +623,7 @@ files = [ name = "cryptography" version = "41.0.4" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -629,6 +669,7 @@ test-randomorder = ["pytest-randomly"] name = "cycler" version = "0.11.0" description = "Composable style cycles" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -640,6 +681,7 @@ files = [ name = "darglint" version = "1.8.1" description = "A utility for ensuring Google-style docstrings stay up to date with the source code." +category = "dev" optional = false python-versions = ">=3.6,<4.0" files = [ @@ -651,6 +693,7 @@ files = [ name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -678,6 +721,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -689,6 +733,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -700,6 +745,7 @@ files = [ name = "distlib" version = "0.3.7" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -711,6 +757,7 @@ files = [ name = "docutils" version = "0.17.1" description = "Docutils -- Python Documentation Utilities" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -722,6 +769,7 @@ files = [ name = "dulwich" version = "0.21.5" description = "Python Git Library" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -796,6 +844,7 @@ pgp = ["gpg"] name = "entrypoints" version = "0.4" description = "Discover and load entry points from installed packages." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -807,6 +856,7 @@ files = [ name = "eradicate" version = "2.3.0" description = "Removes commented-out code." +category = "dev" optional = false python-versions = "*" files = [ @@ -818,6 +868,7 @@ files = [ name = "exceptiongroup" version = "1.1.2" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -832,6 +883,7 @@ test = ["pytest (>=6)"] name = "executing" version = "1.2.0" description = "Get the currently executing AST node of a frame, and other information" +category = "dev" optional = false python-versions = "*" files = [ @@ -846,6 +898,7 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fastjsonschema" version = "2.18.0" description = "Fastest Python implementation of JSON schema" +category = "dev" optional = false python-versions = "*" files = [ @@ -860,6 +913,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.2" description = "A platform independent file lock." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -875,6 +929,7 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "p name = "flake8" version = "4.0.1" description = "the modular source code checker: pep8 pyflakes and co" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -891,6 +946,7 @@ pyflakes = ">=2.4.0,<2.5.0" name = "flake8-bandit" version = "3.0.0" description = "Automated security testing with bandit and flake8." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -908,6 +964,7 @@ pycodestyle = "*" name = "flake8-bugbear" version = "22.12.6" description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -926,6 +983,7 @@ dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit", "tox"] name = "flake8-builtins" version = "1.5.3" description = "Check for python builtins being used as variables or parameters." +category = "dev" optional = false python-versions = "*" files = [ @@ -943,6 +1001,7 @@ test = ["coverage", "coveralls", "mock", "pytest", "pytest-cov"] name = "flake8-comprehensions" version = "3.14.0" description = "A flake8 plugin to help you write better list/set/dict comprehensions." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -957,6 +1016,7 @@ flake8 = ">=3.0,<3.2.0 || >3.2.0" name = "flake8-docstrings" version = "1.7.0" description = "Extension for flake8 which uses pydocstyle to check docstrings" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -972,6 +1032,7 @@ pydocstyle = ">=2.1" name = "flake8-eradicate" version = "1.4.0" description = "Flake8 plugin to find commented out code" +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -988,6 +1049,7 @@ flake8 = ">=3.5,<6" name = "flake8-isort" version = "4.2.0" description = "flake8 plugin that integrates isort ." +category = "dev" optional = false python-versions = "*" files = [ @@ -1006,6 +1068,7 @@ test = ["pytest-cov"] name = "flake8-mutable" version = "1.2.0" description = "mutable defaults flake8 extension" +category = "dev" optional = false python-versions = "*" files = [ @@ -1020,6 +1083,7 @@ flake8 = "*" name = "flake8-plugin-utils" version = "1.3.3" description = "The package provides base classes and utils for flake8 plugin writing" +category = "dev" optional = false python-versions = ">=3.6,<4.0" files = [ @@ -1031,6 +1095,7 @@ files = [ name = "flake8-polyfill" version = "1.0.2" description = "Polyfill package for Flake8 plugins" +category = "dev" optional = false python-versions = "*" files = [ @@ -1045,6 +1110,7 @@ flake8 = "*" name = "flake8-pytest-style" version = "1.7.2" description = "A flake8 plugin checking common style issues or inconsistencies with pytest-based tests." +category = "dev" optional = false python-versions = ">=3.7.2,<4.0.0" files = [ @@ -1059,6 +1125,7 @@ flake8-plugin-utils = ">=1.3.2,<2.0.0" name = "flake8-spellcheck" version = "0.25.0" description = "Spellcheck variables, comments and docstrings" +category = "dev" optional = false python-versions = "*" files = [ @@ -1073,6 +1140,7 @@ flake8 = ">3.0.0" name = "flakeheaven" version = "3.3.0" description = "FlakeHeaven is a [Flake8](https://gitlab.com/pycqa/flake8) wrapper to make it cool." +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -1095,6 +1163,7 @@ docs = ["alabaster", "myst-parser (>=0.18.0,<0.19.0)", "pygments-github-lexers", name = "fonttools" version = "4.41.1" description = "Tools to manipulate font files" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1152,6 +1221,7 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "gitdb" version = "4.0.10" description = "Git Object Database" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1166,6 +1236,7 @@ smmap = ">=3.0.1,<6" name = "gitpython" version = "3.1.37" description = "GitPython is a Python library used to interact with Git repositories" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1183,6 +1254,7 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mypy", "pre-commit" name = "html5lib" version = "1.1" description = "HTML parser based on the WHATWG HTML specification" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1204,6 +1276,7 @@ lxml = ["lxml"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1215,6 +1288,7 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1226,6 +1300,7 @@ files = [ name = "importlib-metadata" version = "6.8.0" description = "Read metadata from Python packages" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1245,6 +1320,7 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs name = "importlib-resources" version = "6.0.0" description = "Read resources from Python packages" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1263,6 +1339,7 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1274,6 +1351,7 @@ files = [ name = "installer" version = "0.7.0" description = "A library for installing Python wheels." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1283,13 +1361,14 @@ files = [ [[package]] name = "ipykernel" -version = "6.24.0" +version = "6.25.0" description = "IPython Kernel for Jupyter" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.24.0-py3-none-any.whl", hash = "sha256:2f5fffc7ad8f1fd5aadb4e171ba9129d9668dbafa374732cf9511ada52d6547f"}, - {file = "ipykernel-6.24.0.tar.gz", hash = "sha256:29cea0a716b1176d002a61d0b0c851f34536495bc4ef7dd0222c88b41b816123"}, + {file = "ipykernel-6.25.0-py3-none-any.whl", hash = "sha256:f0042e867ac3f6bca1679e6a88cbd6a58ed93a44f9d0866aecde6efe8de76659"}, + {file = "ipykernel-6.25.0.tar.gz", hash = "sha256:e342ce84712861be4b248c4a73472be4702c1b0dd77448bfd6bcfb3af9d5ddf9"}, ] [package.dependencies] @@ -1298,7 +1377,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1318,6 +1397,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.2" description = "IPython: Productive Interactive Computing" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1357,6 +1437,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "isort" version = "5.12.0" description = "A Python utility / library to sort Python imports." +category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -1374,6 +1455,7 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "jaraco-classes" version = "3.3.0" description = "Utility functions for Python class constructs" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1390,27 +1472,29 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", [[package]] name = "jedi" -version = "0.18.2" +version = "0.19.0" description = "An autocompletion tool for Python that can be used for text editors." +category = "dev" optional = false python-versions = ">=3.6" files = [ - {file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"}, - {file = "jedi-0.18.2.tar.gz", hash = "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612"}, + {file = "jedi-0.19.0-py2.py3-none-any.whl", hash = "sha256:cb8ce23fbccff0025e9386b5cf85e892f94c9b822378f8da49970471335ac64e"}, + {file = "jedi-0.19.0.tar.gz", hash = "sha256:bcf9894f1753969cbac8022a8c2eaee06bfa3724e4192470aaffe7eb6272b0c4"}, ] [package.dependencies] -parso = ">=0.8.0,<0.9.0" +parso = ">=0.8.3,<0.9.0" [package.extras] docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] -qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] [[package]] name = "jeepney" version = "0.8.0" description = "Low-level, pure Python DBus protocol wrapper." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1426,6 +1510,7 @@ trio = ["async_generator", "trio"] name = "jinja2" version = "3.0.3" description = "A very fast and expressive template engine." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1454,6 +1539,7 @@ files = [ name = "jsonschema" version = "4.18.4" description = "An implementation of JSON Schema validation for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1477,6 +1563,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2023.7.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1492,6 +1579,7 @@ referencing = ">=0.28.0" name = "jupyter-client" version = "8.3.0" description = "Jupyter protocol implementation and client libraries" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1501,7 +1589,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -1515,6 +1603,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-core" version = "5.3.1" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1535,6 +1624,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1546,6 +1636,7 @@ files = [ name = "keyring" version = "23.13.1" description = "Store and access your passwords safely." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1570,6 +1661,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1647,6 +1739,7 @@ files = [ name = "linecache2" version = "1.0.0" description = "Backports of the linecache module" +category = "dev" optional = false python-versions = "*" files = [ @@ -1658,6 +1751,7 @@ files = [ name = "lockfile" version = "0.12.2" description = "Platform-independent file locking module" +category = "dev" optional = false python-versions = "*" files = [ @@ -1669,6 +1763,7 @@ files = [ name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1738,6 +1833,7 @@ files = [ name = "matplotlib" version = "3.7.2" description = "Python plotting package" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1800,6 +1896,7 @@ python-dateutil = ">=2.7" name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1814,6 +1911,7 @@ traitlets = "*" name = "mccabe" version = "0.6.1" description = "McCabe checker, plugin for flake8" +category = "dev" optional = false python-versions = "*" files = [ @@ -1825,6 +1923,7 @@ files = [ name = "mistune" version = "2.0.5" description = "A sane Markdown parser with useful plugins and renderers" +category = "dev" optional = false python-versions = "*" files = [ @@ -1836,6 +1935,7 @@ files = [ name = "more-itertools" version = "10.0.0" description = "More routines for operating on iterables, beyond itertools" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1847,6 +1947,7 @@ files = [ name = "msgpack" version = "1.0.5" description = "MessagePack serializer" +category = "dev" optional = false python-versions = "*" files = [ @@ -1919,6 +2020,7 @@ files = [ name = "nbclient" version = "0.8.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -1928,7 +2030,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" nbformat = ">=5.1" traitlets = ">=5.4" @@ -1941,6 +2043,7 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= name = "nbconvert" version = "7.2.10" description = "Converting Jupyter Notebooks" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1979,6 +2082,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.9.1" description = "The Jupyter Notebook format" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1998,19 +2102,21 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] [[package]] name = "nest-asyncio" -version = "1.5.6" +version = "1.5.7" description = "Patch asyncio to allow nested event loops" +category = "dev" optional = false python-versions = ">=3.5" files = [ - {file = "nest_asyncio-1.5.6-py3-none-any.whl", hash = "sha256:b9a953fb40dceaa587d109609098db21900182b16440652454a146cffb06e8b8"}, - {file = "nest_asyncio-1.5.6.tar.gz", hash = "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"}, + {file = "nest_asyncio-1.5.7-py3-none-any.whl", hash = "sha256:5301c82941b550b3123a1ea772ba9a1c80bad3a182be8c1a5ae6ad3be57a9657"}, + {file = "nest_asyncio-1.5.7.tar.gz", hash = "sha256:6a80f7b98f24d9083ed24608977c09dd608d83f91cccc24c9d2cba6d10e01c10"}, ] [[package]] name = "networkx" version = "2.8.7" description = "Python package for creating and manipulating graphs and networks" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2029,6 +2135,7 @@ test = ["codecov (>=2.1)", "pytest (>=7.1)", "pytest-cov (>=3.0)"] name = "numpy" version = "1.24.4" description = "Fundamental package for array computing in Python" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2066,6 +2173,7 @@ files = [ name = "packaging" version = "23.1" description = "Core utilities for Python packages" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2077,6 +2185,7 @@ files = [ name = "pandas" version = "1.5.3" description = "Powerful data structures for data analysis, time series, and statistics" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2124,6 +2233,7 @@ test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2135,6 +2245,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2150,6 +2261,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pbr" version = "5.11.1" description = "Python Build Reasonableness" +category = "dev" optional = false python-versions = ">=2.6" files = [ @@ -2161,6 +2273,7 @@ files = [ name = "pep8-naming" version = "0.12.1" description = "Check PEP-8 naming conventions, plugin for flake8" +category = "dev" optional = false python-versions = "*" files = [ @@ -2176,6 +2289,7 @@ flake8-polyfill = ">=1.0.2,<2" name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "dev" optional = false python-versions = "*" files = [ @@ -2190,6 +2304,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "dev" optional = false python-versions = "*" files = [ @@ -2201,6 +2316,7 @@ files = [ name = "pillow" version = "10.0.1" description = "Python Imaging Library (Fork)" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2268,6 +2384,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkginfo" version = "1.9.6" description = "Query metadata from sdists / bdists / installed packages." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2282,6 +2399,7 @@ testing = ["pytest", "pytest-cov"] name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2291,23 +2409,25 @@ files = [ [[package]] name = "platformdirs" -version = "3.9.1" +version = "3.10.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"}, - {file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"}, + {file = "platformdirs-3.10.0-py3-none-any.whl", hash = "sha256:d7c24979f292f916dc9cbf8648319032f551ea8c49a4c9bf2fb556a02070ec1d"}, + {file = "platformdirs-3.10.0.tar.gz", hash = "sha256:b45696dab2d7cc691a3226759c0d3b00c47c8b6e293d96f6436f733303f77f6d"}, ] [package.extras] -docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] [[package]] name = "pluggy" version = "1.2.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2323,6 +2443,7 @@ testing = ["pytest", "pytest-benchmark"] name = "poetry" version = "1.5.1" description = "Python dependency management and packaging made easy." +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -2364,6 +2485,7 @@ xattr = {version = ">=0.10.0,<0.11.0", markers = "sys_platform == \"darwin\""} name = "poetry-core" version = "1.6.1" description = "Poetry PEP 517 Build Backend" +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -2375,6 +2497,7 @@ files = [ name = "poetry-plugin-export" version = "1.4.0" description = "Poetry plugin to export the dependencies to various formats" +category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -2390,6 +2513,7 @@ poetry-core = ">=1.6.0,<2.0.0" name = "prompt-toolkit" version = "3.0.39" description = "Library for building powerful interactive command lines in Python" +category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -2404,6 +2528,7 @@ wcwidth = "*" name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2430,6 +2555,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -2441,6 +2567,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" +category = "dev" optional = false python-versions = "*" files = [ @@ -2451,10 +2578,41 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pybind11" +version = "2.11.1" +description = "Seamless operability between C++11 and Python" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pybind11-2.11.1-py3-none-any.whl", hash = "sha256:33cdd02a6453380dd71cc70357ce388ad1ee8d32bd0e38fc22b273d050aa29b3"}, + {file = "pybind11-2.11.1.tar.gz", hash = "sha256:00cd59116a6e8155aecd9174f37ba299d1d397ed4a6b86ac1dfe01b3e40f2cc4"}, +] + +[package.dependencies] +pybind11-global = {version = "2.11.1", optional = true, markers = "extra == \"global\""} + +[package.extras] +global = ["pybind11-global (==2.11.1)"] + +[[package]] +name = "pybind11-global" +version = "2.11.1" +description = "Seamless operability between C++11 and Python" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pybind11_global-2.11.1-py3-none-any.whl", hash = "sha256:9664c675af3225b86f0d93873fe76f6c2d2966bc9ed64276fb0f73dfcb181806"}, + {file = "pybind11_global-2.11.1.tar.gz", hash = "sha256:1ba797947bd375f48717377117fc5d31f2c8b26cf8f5e2a907ab38f380b75324"}, +] + [[package]] name = "pycodestyle" version = "2.8.0" description = "Python style guide checker" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -2466,6 +2624,7 @@ files = [ name = "pycparser" version = "2.21" description = "C parser in Python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2477,6 +2636,7 @@ files = [ name = "pydocstyle" version = "6.3.0" description = "Python docstring style checker" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2494,6 +2654,7 @@ toml = ["tomli (>=1.2.3)"] name = "pyflakes" version = "2.4.0" description = "passive checker of Python programs" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2505,6 +2666,7 @@ files = [ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2519,6 +2681,7 @@ plugins = ["importlib-metadata"] name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "dev" optional = false python-versions = ">=3.6.8" files = [ @@ -2533,6 +2696,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyproject-hooks" version = "1.0.0" description = "Wrappers to call pyproject.toml-based build backend hooks." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2547,6 +2711,7 @@ tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} name = "pytest" version = "7.4.0" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2569,6 +2734,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "3.0.0" description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2587,6 +2753,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -2601,6 +2768,7 @@ six = ">=1.5" name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" +category = "dev" optional = false python-versions = "*" files = [ @@ -2612,6 +2780,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "dev" optional = false python-versions = "*" files = [ @@ -2635,6 +2804,7 @@ files = [ name = "pywin32-ctypes" version = "0.2.2" description = "A (partial) reimplementation of pywin32 using ctypes/cffi" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2646,6 +2816,7 @@ files = [ name = "pyyaml" version = "6.0.1" description = "YAML parser and emitter for Python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2705,6 +2876,7 @@ files = [ name = "pyzmq" version = "25.1.0" description = "Python bindings for 0MQ" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2794,6 +2966,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "rapidfuzz" version = "2.15.1" description = "rapid fuzzy string matching" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2898,6 +3071,7 @@ full = ["numpy"] name = "referencing" version = "0.30.0" description = "JSON Referencing + Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2913,6 +3087,7 @@ rpds-py = ">=0.7.0" name = "requests" version = "2.31.0" description = "Python HTTP for Humans." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2934,6 +3109,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-toolbelt" version = "1.0.0" description = "A utility belt for advanced users of python-requests" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2948,6 +3124,7 @@ requests = ">=2.0.1,<3.0.0" name = "rpds-py" version = "0.9.2" description = "Python bindings to Rust's persistent data structures (rpds)" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3101,6 +3278,7 @@ tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc ( name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = false python-versions = "<3.12,>=3.8" files = [ @@ -3139,6 +3317,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "secretstorage" version = "3.3.3" description = "Python bindings to FreeDesktop.org Secret Service API" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3154,6 +3333,7 @@ jeepney = ">=0.6" name = "shellingham" version = "1.5.0.post1" description = "Tool to Detect Surrounding Shell" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3165,6 +3345,7 @@ files = [ name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3176,6 +3357,7 @@ files = [ name = "smmap" version = "5.0.0" description = "A pure Python implementation of a sliding window memory map manager" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3187,6 +3369,7 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." +category = "dev" optional = false python-versions = "*" files = [ @@ -3198,6 +3381,7 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3209,6 +3393,7 @@ files = [ name = "sphinx" version = "4.5.0" description = "Python documentation generator" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3244,6 +3429,7 @@ test = ["cython", "html5lib", "pytest", "pytest-cov", "typed-ast"] name = "sphinx-rtd-theme" version = "1.2.2" description = "Read the Docs theme for Sphinx" +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3263,6 +3449,7 @@ dev = ["bump2version", "sphinxcontrib-httpdomain", "transifex-client", "wheel"] name = "sphinx-tabs" version = "3.4.0" description = "Tabbed views for Sphinx" +category = "dev" optional = false python-versions = "~=3.7" files = [ @@ -3284,6 +3471,7 @@ testing = ["bs4", "coverage", "pygments", "pytest (>=7.1,<8)", "pytest-cov", "py name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3299,6 +3487,7 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -3314,6 +3503,7 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3329,6 +3519,7 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jquery" version = "4.1" description = "Extension to include jQuery on newer Sphinx releases" +category = "dev" optional = false python-versions = ">=2.7" files = [ @@ -3343,6 +3534,7 @@ Sphinx = ">=1.8" name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -3357,6 +3549,7 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -3372,6 +3565,7 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -3387,6 +3581,7 @@ test = ["pytest"] name = "stack-data" version = "0.6.2" description = "Extract data from python stack frames and tracebacks for informative displays" +category = "dev" optional = false python-versions = "*" files = [ @@ -3406,6 +3601,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "stevedore" version = "5.1.0" description = "Manage dynamic plugins for Python applications" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3431,6 +3627,7 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3449,6 +3646,7 @@ test = ["flake8", "isort", "pytest"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3460,6 +3658,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3469,19 +3668,21 @@ files = [ [[package]] name = "tomlkit" -version = "0.11.8" +version = "0.12.1" description = "Style preserving TOML library" +category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "tomlkit-0.11.8-py3-none-any.whl", hash = "sha256:8c726c4c202bdb148667835f68d68780b9a003a9ec34167b6c673b38eff2a171"}, - {file = "tomlkit-0.11.8.tar.gz", hash = "sha256:9330fc7faa1db67b541b28e62018c17d20be733177d290a13b24c62d1614e0c3"}, + {file = "tomlkit-0.12.1-py3-none-any.whl", hash = "sha256:712cbd236609acc6a3e2e97253dfc52d4c2082982a88f61b640ecf0817eab899"}, + {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, ] [[package]] name = "tornado" version = "6.3.3" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -3502,6 +3703,7 @@ files = [ name = "traceback2" version = "1.4.0" description = "Backports of the traceback module" +category = "dev" optional = false python-versions = "*" files = [ @@ -3516,6 +3718,7 @@ linecache2 = "*" name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3531,6 +3734,7 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "trove-classifiers" version = "2023.7.6" description = "Canonical source for classifiers on PyPI (pypi.org)." +category = "dev" optional = false python-versions = "*" files = [ @@ -3542,6 +3746,7 @@ files = [ name = "typing-extensions" version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3553,6 +3758,7 @@ files = [ name = "unittest2" version = "1.1.0" description = "The new features in unittest backported to Python 2.4+." +category = "dev" optional = false python-versions = "*" files = [ @@ -3569,6 +3775,7 @@ traceback2 = "*" name = "urllib3" version = "1.26.18" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -3585,6 +3792,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "virtualenv" version = "20.24.2" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3605,6 +3813,7 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -3616,6 +3825,7 @@ files = [ name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "dev" optional = false python-versions = "*" files = [ @@ -3627,6 +3837,7 @@ files = [ name = "xattr" version = "0.10.1" description = "Python wrapper for extended filesystem attributes" +category = "dev" optional = false python-versions = "*" files = [ @@ -3711,6 +3922,7 @@ cffi = ">=1.0" name = "zipp" version = "3.16.2" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3725,4 +3937,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.11" -content-hash = "d2ed74644f3e06a1e22de2a3ae7469cb8935a019442c43ed98dffdc3b6f13e0e" +content-hash = "b5459ff265a54ed9ccfa386f8737d34d3cb996860f26a6119897d255bc14df6c" diff --git a/prebuild.py b/prebuild.py new file mode 100644 index 000000000..26996d2af --- /dev/null +++ b/prebuild.py @@ -0,0 +1,89 @@ +import os +import platform +import multiprocessing +import numpy +import subprocess # nosec +import sys + + +def build_msg_lib() -> bool: + pure_py_env = os.getenv("LAVA_PURE_PYTHON", 0) + system_name = platform.system().lower() + if system_name != "linux": + return False + return int(pure_py_env) == 0 + + +class CMake: + def __init__(self, sourcedir, targetdir): + self.sourcedir = os.path.abspath(sourcedir) + self.targetdir = os.path.abspath(targetdir) + self.env = os.environ.copy() + self.from_poetry = self._check_poetry() + self.from_cd_action = self._check_cd_action() + self.cmake_command = ["poetry", "run", "cmake"] \ + if self.from_cd_action and self.from_poetry else ["cmake"] + self.cmake_args = [] + self.build_args = [] + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + self.parallel = multiprocessing.cpu_count() + + def _check_poetry(self): + exec_code = self.env.get('_', "").rsplit('/')[-1] + if exec_code == 'poetry': + return True + return False + + def _check_cd_action(self): + event_name = self.env.get('GITHUB_EVENT_NAME', '') + return event_name == 'workflow_dispatch' + + def _set_cmake_path(self): + # pylint: disable=W0201 + self.temp_path = os.path.join(os.path.abspath(""), "build") + if not os.path.exists(self.temp_path): + os.makedirs(self.temp_path) + + def _set_cmake_args(self): + debug = int(os.environ.get("DEBUG", 0)) + cfg = "Debug" if debug else "Release" + if self.from_cd_action and self.from_poetry: + # noqa: B603 + python_env = subprocess.check_output(["poetry", "env", "info", "-p"]) \ + .decode().strip() + "/bin/python3" # nosec # noqa + numpy_include_dir = subprocess.check_output(["poetry", "run", # nosec # noqa + "python3", "-c", "import numpy; print(numpy.get_include())"]).decode().strip() # nosec # noqa + else: + python_env = sys.executable + numpy_include_dir = numpy.get_include() + self.cmake_args += [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={self.targetdir}", + f"-DPYTHON_EXECUTABLE={python_env}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + if "CMAKE_ARGS" in os.environ: + self.cmake_args += [item for item in + os.environ["CMAKE_ARGS"].split(" ") if item] + # Set numpy include header to cpplib + self.cmake_args += [ + f"-DNUMPY_INCLUDE_DIRS={numpy_include_dir}"] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + self.build_args += [f"-j{self.parallel}"] + + def run(self): + self._set_cmake_path() + self._set_cmake_args() + subprocess.check_call([*self.cmake_command, self.sourcedir] + self.cmake_args, cwd=self.temp_path, env=self.env) # nosec # noqa + subprocess.check_call([*self.cmake_command, "--build", "."] + self.build_args, cwd=self.temp_path, env=self.env) # nosec # noqa + + +if __name__ == '__main__': + base_runtime_path = "src/lava/magma/runtime/" + sourcedir = f"{base_runtime_path}_c_message_infrastructure" + targetdir = f"{base_runtime_path}message_infrastructure" + if build_msg_lib(): + cmake = CMake(sourcedir, targetdir) + cmake.run() diff --git a/pyproject.toml b/pyproject.toml index 9d5f7f851..787e612b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["poetry-core>=1.2.0"] +requires = ["poetry-core>=1.2.0", "wheel", "cmake>=3.12", "numpy>=1.22.2", "pybind11[global]>=2.10.1"] build-backend = "poetry.core.masonry.api" [tool.poetry] @@ -9,7 +9,10 @@ packages = [ {include = "lava", from = "src"}, {include = "tests"} ] -include = ["tutorials"] + +include = ["tutorials", + "src/lava/magma/runtime/message_infrastructure/*.so", + "src/lava/magma/runtime/message_infrastructure/install/lib/lib*"] version = "0.8.0.dev0" readme = "README.md" description = "A Software Framework for Neuromorphic Computing" @@ -52,6 +55,7 @@ numpy = "^1.24.4" scipy = "^1.10.1" networkx = "<=2.8.7" asteval = "^0.9.31" +pybind11 = {extras = ["global"], version = "^2.10.1"} scikit-learn = "^1.3.1" [tool.poetry.dev-dependencies] @@ -84,6 +88,12 @@ autopep8 = "^1.6.0" ipykernel = "^6.15.0" nbformat = "^5.3.0" nbconvert = ">=7.2.10, <7.3" +cpplint = "^1.6.0" +psutil = "^5.9.4" + +[tool.poetry.build] +generate-setup-file = false +script = "prebuild.py" [tool.black] line-length = 80 @@ -139,7 +149,7 @@ extended_default_ignore=[] # Fix for bug while using newer flake8 ver. format = "grouped" max_line_length = 80 show_source = true -exclude = ["./docs/"] +exclude = ["./docs/", "./src/lava/magma/runtime/message_infrastructure/install"] [tool.flakeheaven.plugins] flake8-bandit = ["+*", "-S322", "-B101", "-S404", "-S602"] # Enable a plugin, disable specific checks diff --git a/src/lava/magma/compiler/builders/channel_builder.py b/src/lava/magma/compiler/builders/channel_builder.py index b8f72eaa6..04ac971a3 100644 --- a/src/lava/magma/compiler/builders/channel_builder.py +++ b/src/lava/magma/compiler/builders/channel_builder.py @@ -11,15 +11,15 @@ AbstractProcessModel from lava.magma.compiler.builders. \ runtimeservice_builder import RuntimeServiceBuilder -from lava.magma.compiler.channels.interfaces import ( +from lava.magma.runtime.message_infrastructure import ( Channel, - ChannelType, ) from lava.magma.compiler.utils import PortInitializer from lava.magma.runtime.message_infrastructure \ .message_infrastructure_interface import (MessageInfrastructureInterface) -from lava.magma.compiler.channels.watchdog import WatchdogManager, Watchdog - +from lava.magma.runtime.message_infrastructure.watchdog import \ + Watchdog, WatchdogManager +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType if ty.TYPE_CHECKING: from lava.magma.core.process.process import AbstractProcess from lava.magma.runtime.runtime import Runtime @@ -127,29 +127,12 @@ def build( Exception Can't build channel of type specified """ - channel_class = messaging_infrastructure.channel_class( - channel_type=self.channel_type - ) - - # Watchdogs - sq = watchdog_manager.sq - queues = (sq, sq, sq, sq) - port_initializers = (self.src_port_initializer, - self.dst_port_initializer) - (src_send_watchdog, src_join_watchdog, - dst_recv_watchdog, dst_join_watchdog) = \ - self.create_watchdogs(watchdog_manager, queues, port_initializers) - - return channel_class( - messaging_infrastructure, - self.src_port_initializer.name, - self.dst_port_initializer.name, - self.src_port_initializer.shape, - self.src_port_initializer.d_type, - self.src_port_initializer.size, - src_send_watchdog, src_join_watchdog, - dst_recv_watchdog, dst_join_watchdog - ) + return messaging_infrastructure.channel(self.channel_type, + self.src_port_initializer.name, + self.dst_port_initializer.name, + self.src_port_initializer.shape, + self.src_port_initializer.d_type, # noqa: E501 + self.src_port_initializer.size) @dataclass @@ -186,30 +169,14 @@ def build( Exception Can't build channel of type specified """ - channel_class = messaging_infrastructure.channel_class( - channel_type=self.channel_type - ) - - # Watchdogs - lq, sq = watchdog_manager.lq, watchdog_manager.sq - queues = (sq, sq, lq, sq) - port_initializers = (self.port_initializer, - self.port_initializer) - (src_send_watchdog, src_join_watchdog, - dst_recv_watchdog, dst_join_watchdog) = \ - self.create_watchdogs(watchdog_manager, queues, port_initializers) - channel_name: str = self.port_initializer.name - return channel_class( - messaging_infrastructure, - channel_name + "_src", - channel_name + "_dst", - self.port_initializer.shape, - self.port_initializer.d_type, - self.port_initializer.size, - src_send_watchdog, src_join_watchdog, - dst_recv_watchdog, dst_join_watchdog - ) + return messaging_infrastructure.channel(self.channel_type, + channel_name + "_src", + channel_name + "_dst", + self.port_initializer.shape, + self.port_initializer.d_type, + self.port_initializer.size, + sync=True) @dataclass @@ -244,30 +211,14 @@ def build( Exception Can't build channel of type specified """ - channel_class = messaging_infrastructure.channel_class( - channel_type=self.channel_type - ) - - # Watchdogs - lq, sq = watchdog_manager.lq, watchdog_manager.sq - queues = (sq, sq, lq, sq) - port_initializers = (self.port_initializer, - self.port_initializer) - (src_send_watchdog, src_join_watchdog, - dst_recv_watchdog, dst_join_watchdog) = \ - self.create_watchdogs(watchdog_manager, queues, port_initializers) - channel_name: str = self.port_initializer.name - return channel_class( - messaging_infrastructure, - channel_name + "_src", - channel_name + "_dst", - self.port_initializer.shape, - self.port_initializer.d_type, - self.port_initializer.size, - src_send_watchdog, src_join_watchdog, - dst_recv_watchdog, dst_join_watchdog - ) + return messaging_infrastructure.channel(self.channel_type, + channel_name + "_src", + channel_name + "_dst", + self.port_initializer.shape, + self.port_initializer.d_type, + self.port_initializer.size, + sync=True) @dataclass diff --git a/src/lava/magma/compiler/builders/py_builder.py b/src/lava/magma/compiler/builders/py_builder.py index c6d3ab656..785f591cd 100644 --- a/src/lava/magma/compiler/builders/py_builder.py +++ b/src/lava/magma/compiler/builders/py_builder.py @@ -8,13 +8,13 @@ from scipy.sparse import csr_matrix from lava.magma.compiler.builders.interfaces import AbstractProcessBuilder -from lava.magma.compiler.channels.interfaces import AbstractCspPort -from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort -from lava.magma.compiler.utils import ( - PortInitializer, - VarInitializer, - VarPortInitializer, +from lava.magma.runtime.message_infrastructure import ( + AbstractTransferPort, + RecvPort, + SendPort ) +from lava.magma.compiler.utils import (PortInitializer, VarInitializer, + VarPortInitializer) from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.model.py.ports import ( AbstractPyIOPort, @@ -65,10 +65,11 @@ def __init__( self.py_ports: ty.Dict[str, PortInitializer] = {} self.ref_ports: ty.Dict[str, PortInitializer] = {} self.var_ports: ty.Dict[str, VarPortInitializer] = {} - self.csp_ports: ty.Dict[str, ty.List[AbstractCspPort]] = {} - self._csp_port_map: ty.Dict[str, ty.Dict[str, AbstractCspPort]] = {} - self.csp_rs_send_port: ty.Dict[str, CspSendPort] = {} - self.csp_rs_recv_port: ty.Dict[str, CspRecvPort] = {} + self.csp_ports: ty.Dict[str, ty.List[AbstractTransferPort]] = {} + self._csp_port_map: ty.Dict[str, + ty.Dict[str, AbstractTransferPort]] = {} + self.csp_rs_send_port: ty.Dict[str, SendPort] = {} + self.csp_rs_recv_port: ty.Dict[str, RecvPort] = {} self.proc_params = proc_params def check_all_vars_and_ports_set(self): @@ -170,13 +171,13 @@ def set_var_ports(self, var_ports: ty.List[VarPortInitializer]): self._check_not_assigned_yet(self.var_ports, new_ports.keys(), "ports") self.var_ports.update(new_ports) - def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): + def set_csp_ports(self, csp_ports: ty.List[AbstractTransferPort]): """Appends the given list of CspPorts to the ProcessModel. Used by the runtime to configure csp ports during initialization (_build_channels). Parameters ---------- - csp_ports : ty.List[AbstractCspPort] + csp_ports : ty.List[AbstractTransferPort] Raises @@ -206,7 +207,8 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): else: self.csp_ports[port_name] = new_ports[port_name] - def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort): + def add_csp_port_mapping(self, py_port_id: str, + csp_port: AbstractTransferPort): """Appends a mapping from a PyPort ID to a CSP port. This is used to associate a CSP port in a PyPort with transformation functions that implement the behavior of virtual ports. @@ -224,7 +226,7 @@ def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort): {py_port_id: csp_port} ) - def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): + def set_rs_csp_ports(self, csp_ports: ty.List[AbstractTransferPort]): """Set RS CSP Ports Parameters @@ -233,9 +235,9 @@ def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): """ for port in csp_ports: - if isinstance(port, CspSendPort): + if isinstance(port, SendPort): self.csp_rs_send_port.update({port.name: port}) - if isinstance(port, CspRecvPort): + if isinstance(port, RecvPort): self.csp_rs_recv_port.update({port.name: port}) def _get_lava_type(self, name: str) -> LavaPyType: @@ -315,16 +317,10 @@ def build(self): csp_send = None if name in self.csp_ports: csp_ports = self.csp_ports[name] - csp_recv = ( - csp_ports[0] - if isinstance(csp_ports[0], CspRecvPort) - else csp_ports[1] - ) - csp_send = ( - csp_ports[0] - if isinstance(csp_ports[0], CspSendPort) - else csp_ports[1] - ) + csp_recv = csp_ports[0] if isinstance( + csp_ports[0], RecvPort) else csp_ports[1] + csp_send = csp_ports[0] if isinstance( + csp_ports[0], SendPort) else csp_ports[1] transformer = ( VirtualPortTransformer( @@ -352,16 +348,10 @@ def build(self): csp_send = None if name in self.csp_ports: csp_ports = self.csp_ports[name] - csp_recv = ( - csp_ports[0] - if isinstance(csp_ports[0], CspRecvPort) - else csp_ports[1] - ) - csp_send = ( - csp_ports[0] - if isinstance(csp_ports[0], CspSendPort) - else csp_ports[1] - ) + csp_recv = csp_ports[0] if isinstance( + csp_ports[0], RecvPort) else csp_ports[1] + csp_send = csp_ports[0] if isinstance( + csp_ports[0], SendPort) else csp_ports[1] transformer = ( VirtualPortTransformer( diff --git a/src/lava/magma/compiler/builders/runtimeservice_builder.py b/src/lava/magma/compiler/builders/runtimeservice_builder.py index ff07c2f7d..894724567 100644 --- a/src/lava/magma/compiler/builders/runtimeservice_builder.py +++ b/src/lava/magma/compiler/builders/runtimeservice_builder.py @@ -5,8 +5,12 @@ import logging import typing as ty -from lava.magma.compiler.channels.interfaces import AbstractCspPort -from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort +from lava.magma.runtime.message_infrastructure import ( + AbstractTransferPort, + RecvPort, + SendPort +) + from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.runtime.runtime_services.enums import LoihiVersion from lava.magma.runtime.runtime_services.runtime_service import \ @@ -51,10 +55,10 @@ def __init__( self._compile_config = compile_config self._runtime_service_id = runtime_service_id self._model_ids: ty.List[int] = model_ids - self.csp_send_port: ty.Dict[str, CspSendPort] = {} - self.csp_recv_port: ty.Dict[str, CspRecvPort] = {} - self.csp_proc_send_port: ty.Dict[str, CspSendPort] = {} - self.csp_proc_recv_port: ty.Dict[str, CspRecvPort] = {} + self.csp_send_port: ty.Dict[str, SendPort] = {} + self.csp_recv_port: ty.Dict[str, RecvPort] = {} + self.csp_proc_send_port: ty.Dict[str, SendPort] = {} + self.csp_proc_recv_port: ty.Dict[str, RecvPort] = {} self.loihi_version: ty.Type[LoihiVersion] = loihi_version @property @@ -62,7 +66,7 @@ def runtime_service_id(self): """Return runtime service id.""" return self._runtime_service_id - def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): + def set_csp_ports(self, csp_ports: ty.List[AbstractTransferPort]): """Set CSP Ports Parameters @@ -71,12 +75,12 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]): """ for port in csp_ports: - if isinstance(port, CspSendPort): + if isinstance(port, SendPort): self.csp_send_port.update({port.name: port}) - if isinstance(port, CspRecvPort): + if isinstance(port, RecvPort): self.csp_recv_port.update({port.name: port}) - def set_csp_proc_ports(self, csp_ports: ty.List[AbstractCspPort]): + def set_csp_proc_ports(self, csp_ports: ty.List[AbstractTransferPort]): """Set CSP Process Ports Parameters @@ -85,9 +89,9 @@ def set_csp_proc_ports(self, csp_ports: ty.List[AbstractCspPort]): """ for port in csp_ports: - if isinstance(port, CspSendPort): + if isinstance(port, SendPort): self.csp_proc_send_port.update({port.name: port}) - if isinstance(port, CspRecvPort): + if isinstance(port, RecvPort): self.csp_proc_recv_port.update({port.name: port}) def build(self) -> AbstractRuntimeService: diff --git a/src/lava/magma/compiler/compiler.py b/src/lava/magma/compiler/compiler.py index 1daddb2ad..4c1438e81 100644 --- a/src/lava/magma/compiler/compiler.py +++ b/src/lava/magma/compiler/compiler.py @@ -46,7 +46,6 @@ class AbstractNcProcessModel: from lava.magma.compiler.builders.runtimeservice_builder import \ RuntimeServiceBuilder from lava.magma.compiler.channel_map import ChannelMap, Payload, PortPair -from lava.magma.compiler.channels.interfaces import ChannelType from lava.magma.compiler.compiler_graphs import ProcGroup, ProcGroupDiGraphs from lava.magma.compiler.compiler_utils import split_proc_builders_by_type from lava.magma.compiler.executable import Executable @@ -67,7 +66,10 @@ class AbstractNcProcessModel: from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol from lava.magma.runtime.runtime import Runtime from lava.magma.runtime.runtime_services.enums import LoihiVersion -from lava.magma.compiler.channels.watchdog import WatchdogManagerBuilder +from lava.magma.runtime.message_infrastructure.interfaces import \ + ChannelType +from lava.magma.runtime.message_infrastructure.watchdog import \ + WatchdogManagerBuilder class Compiler: diff --git a/src/lava/magma/compiler/executable.py b/src/lava/magma/compiler/executable.py index 44cbf5316..773320d09 100644 --- a/src/lava/magma/compiler/executable.py +++ b/src/lava/magma/compiler/executable.py @@ -8,7 +8,8 @@ from dataclasses import dataclass from lava.magma.compiler.builders.interfaces import AbstractChannelBuilder -from lava.magma.compiler.channels.watchdog import WatchdogManagerBuilder +from lava.magma.runtime.message_infrastructure.watchdog import \ + WatchdogManagerBuilder from lava.magma.core.sync.domain import SyncDomain if ty.TYPE_CHECKING: @@ -34,7 +35,7 @@ class Executable: # py_builders: ty.Dict[AbstractProcess, NcProcessBuilder] # c_builders: ty.Dict[AbstractProcess, CProcessBuilder] # nc_builders: ty.Dict[AbstractProcess, PyProcessBuilder] - process_list: ty.List[AbstractProcess] # All leaf processes, flat list. + process_list: ty.List[AbstractProcess] proc_builders: ty.Dict[AbstractProcess, 'AbstractProcessBuilder'] channel_builders: ty.List[ChannelBuilderMp] node_configs: ty.List[NodeConfig] @@ -46,5 +47,5 @@ class Executable: watchdog_manager_builder: WatchdogManagerBuilder = None def assign_runtime_to_all_processes(self, runtime): - for p in self.process_list: + for p in self.proc_builders.keys(): p.runtime = runtime diff --git a/src/lava/magma/compiler/subcompilers/channel_builders_factory.py b/src/lava/magma/compiler/subcompilers/channel_builders_factory.py index 6b7133796..99e9b1b5a 100644 --- a/src/lava/magma/compiler/subcompilers/channel_builders_factory.py +++ b/src/lava/magma/compiler/subcompilers/channel_builders_factory.py @@ -9,7 +9,7 @@ ChannelBuilderNx, ChannelBuilderPyNc, ) from lava.magma.compiler.channel_map import PortPair, ChannelMap -from lava.magma.compiler.channels.interfaces import ChannelType +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType from lava.magma.compiler.utils import PortInitializer, LoihiConnectedPortType, \ LoihiConnectedPortEncodingType from lava.magma.compiler.var_model import LoihiAddress diff --git a/src/lava/magma/compiler/utils.py b/src/lava/magma/compiler/utils.py index 7c50124df..32da20242 100644 --- a/src/lava/magma/compiler/utils.py +++ b/src/lava/magma/compiler/utils.py @@ -4,6 +4,7 @@ import functools as ft import typing as ty +import numpy as np from dataclasses import dataclass from enum import IntEnum @@ -32,6 +33,12 @@ class PortInitializer: size: int transform_funcs: ty.Dict[str, ty.List[ft.partial]] = None + @property + def bytes(self) -> int: + data_type = np.int32 if str(self.d_type) == "LavaCDataType.INT32" \ + else self.d_type + return np.prod(self.shape) * np.dtype(data_type).itemsize + # check if can be a subclass of PortInitializer @dataclass @@ -45,6 +52,12 @@ class VarPortInitializer: port_cls: type transform_funcs: ty.Dict[str, ty.List[ft.partial]] = None + @property + def bytes(self) -> int: + data_type = np.int32 if str(self.d_type) == "LavaCDataType.INT32" \ + else self.d_type + return np.prod(self.shape) * np.dtype(data_type).itemsize + @dataclass class LoihiVarInitializer(VarInitializer): diff --git a/src/lava/magma/core/model/interfaces.py b/src/lava/magma/core/model/interfaces.py index bae0367be..cf00a438a 100644 --- a/src/lava/magma/core/model/interfaces.py +++ b/src/lava/magma/core/model/interfaces.py @@ -4,7 +4,7 @@ import typing as ty from abc import ABC, abstractmethod -from lava.magma.compiler.channels.interfaces import AbstractCspPort +from lava.magma.runtime.message_infrastructure import AbstractTransferPort class AbstractPortImplementation(ABC): @@ -25,7 +25,7 @@ def shape(self) -> ty.Tuple[int, ...]: @property @abstractmethod - def csp_ports(self) -> ty.List[AbstractCspPort]: + def csp_ports(self) -> ty.List[AbstractTransferPort]: """Returns all csp ports of the port.""" def start(self): diff --git a/src/lava/magma/core/model/py/model.py b/src/lava/magma/core/model/py/model.py index 40235d4e1..09934eebe 100644 --- a/src/lava/magma/core/model/py/model.py +++ b/src/lava/magma/core/model/py/model.py @@ -4,19 +4,22 @@ import typing as ty from abc import ABC, abstractmethod +# from functools import partial import logging from lava.utils.sparse import find import numpy as np from scipy.sparse import csr_matrix import platform -from lava.magma.compiler.channels.pypychannel import ( - CspSendPort, - CspRecvPort, - CspSelector, -) +from lava.magma.runtime.message_infrastructure import (SendPort, + RecvPort, + SupportTempChannel, + getTempSendPort, + getTempRecvPort, + Selector) from lava.magma.core.model.model import AbstractProcessModel -from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort, PyOutPort +from lava.magma.core.model.interfaces import AbstractPortImplementation +from lava.magma.core.model.py.ports import PyVarPort, AbstractPyPort from lava.magma.runtime.mgmt_token_enums import ( enum_to_np, enum_equal, @@ -45,17 +48,17 @@ def __init__( ) -> None: super().__init__(proc_params=proc_params, loglevel=loglevel) self.model_id: ty.Optional[int] = None - self.service_to_process: ty.Optional[CspRecvPort] = None - self.process_to_service: ty.Optional[CspSendPort] = None - self.py_ports: ty.List[AbstractPyPort] = [] + self.service_to_process: ty.Optional[RecvPort] = None + self.process_to_service: ty.Optional[SendPort] = None + self.py_ports: ty.List[AbstractPortImplementation] = [] self.var_ports: ty.List[PyVarPort] = [] self.var_id_to_var_map: ty.Dict[int, ty.Any] = {} - self._selector: CspSelector = CspSelector() - self._action: str = "cmd" + self._selector: Selector = Selector() + self._action: str = 'cmd' self._stopped: bool = False - self._channel_actions: ty.List[ - ty.Tuple[ty.Union[CspSendPort, CspRecvPort], ty.Callable] - ] = [] + self._channel_actions: ty.List[ty.Tuple[ty.Union[SendPort, + RecvPort], + ty.Callable]] = [] self._cmd_handlers: ty.Dict[MGMT_COMMAND, ty.Callable] = { MGMT_COMMAND.STOP[0]: self._stop, MGMT_COMMAND.PAUSE[0]: self._pause, @@ -114,30 +117,46 @@ def _get_var(self): var = getattr(self, var_name) # 2. Send Var data - data_port = self.process_to_service - # Header corresponds to number of values - # Data is either send once (for int) or one by one (array) - if isinstance(var, int) or isinstance(var, np.int32): - data_port.send(enum_to_np(1)) - data_port.send(enum_to_np(var)) - elif isinstance(var, np.ndarray): - # FIXME: send a whole vector (also runtime_service.py) - var_iter = np.nditer(var, order="C") - num_items: np.int32 = np.prod(var.shape) - data_port.send(enum_to_np(num_items)) - for value in var_iter: - data_port.send(enum_to_np(value, np.float64)) - elif isinstance(var, csr_matrix): - _, _, values = find(var, explicit_zeros=True) - num_items = var.data.size - data_port.send(enum_to_np(num_items)) - for value in values: - data_port.send(enum_to_np(value, np.float64)) - elif isinstance(var, str): - encoded_str = list(var.encode("ascii")) - data_port.send(enum_to_np(len(encoded_str))) - for ch in encoded_str: - data_port.send(enum_to_np(ch, d_type=np.int32)) + if SupportTempChannel: + addr_path = self.service_to_process.recv() + data_port = getTempSendPort(str(addr_path[0])) + data_port.start() + if isinstance(var, int) or isinstance(var, np.int32): + data_port.send(enum_to_np(var)) + elif isinstance(var, np.ndarray): + # FIXME: send a whole vector (also runtime_service.py) + data_port.send(var) + elif isinstance(var, csr_matrix): + _, _, data = find(var, explicit_zeros=True) + data_port.send(data) + elif isinstance(var, str): + data_port.send(np.array(var, dtype=str)) + data_port.join() + else: + data_port = self.process_to_service + # Header corresponds to number of values + # Data is either send once (for int) or one by one (array) + if isinstance(var, int) or isinstance(var, np.int32): + data_port.send(enum_to_np(1)) + data_port.send(enum_to_np(var)) + elif isinstance(var, np.ndarray): + # FIXME: send a whole vector (also runtime_service.py) + var_iter = np.nditer(var, order='C') + num_items: np.integer = np.prod(var.shape) + data_port.send(enum_to_np(num_items)) + for value in var_iter: + data_port.send(enum_to_np(value, np.float64)) + elif isinstance(var, csr_matrix): + _, _, values = find(var, explicit_zeros=True) + num_items = var.data.size + data_port.send(enum_to_np(num_items)) + for value in values: + data_port.send(enum_to_np(value, np.float64)) + elif isinstance(var, str): + encoded_str = list(var.encode("ascii")) + data_port.send(enum_to_np(len(encoded_str))) + for ch in encoded_str: + data_port.send(enum_to_np(ch, d_type=np.int32)) def _set_var(self): """Handles the set Var command from runtime service.""" @@ -147,55 +166,84 @@ def _set_var(self): var = getattr(self, var_name) # 2. Receive Var data - data_port = self.service_to_process - if isinstance(var, int) or isinstance(var, np.int32): - # First item is number of items (1) - not needed - data_port.recv() - # Data to set - buffer = data_port.recv()[0] - if isinstance(var, int): - setattr(self, var_name, buffer.item()) - else: + if SupportTempChannel: + addr_path, data_port = getTempRecvPort() + data_port.start() + self.process_to_service.send(np.array([addr_path])) + buffer = data_port.recv() + data_port.join() + if isinstance(var, int) or isinstance(var, np.int32): + buffer = buffer[0] + if isinstance(var, int): + setattr(self, var_name, buffer.item()) + else: + setattr(self, var_name, buffer.astype(var.dtype)) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + elif isinstance(var, np.ndarray): + var_iter = np.nditer(var, op_flags=['readwrite']) setattr(self, var_name, buffer.astype(var.dtype)) - self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) - elif isinstance(var, np.ndarray): - # First item is number of items - num_items = data_port.recv()[0] - var_iter = np.nditer(var, op_flags=["readwrite"]) - # Set data one by one - for i in var_iter: - if num_items == 0: - break - num_items -= 1 - i[...] = data_port.recv()[0] - self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) - elif isinstance(var, csr_matrix): - # First item is number of items - num_items = int(data_port.recv()[0]) - - buffer = np.empty(num_items) - # Set data one by one - for i in range(num_items): - buffer[i] = data_port.recv()[0] - dst, src, _ = find(var) - var = csr_matrix((buffer, (dst, src)), var.shape) - setattr(self, var_name, var) - - self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) - elif isinstance(var, str): - # First item is number of items - num_items = int(data_port.recv()[0]) - - s = [] - for i in range(num_items): - s.append(int(data_port.recv()[0])) # decode string from ascii - - s = bytes(s).decode("ascii") - setattr(self, var_name, s) - self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + elif isinstance(var, csr_matrix): + dst, src, _ = find(var) + var = csr_matrix((buffer, (dst, src)), var.shape) + setattr(self, var_name, var) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + elif isinstance(var, str): + setattr(self, var_name, np.array_str(buffer)) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + else: + self.process_to_service.send(MGMT_RESPONSE.ERROR) + raise RuntimeError("Unsupported type") else: - self.process_to_service.send(MGMT_RESPONSE.ERROR) - raise RuntimeError("Unsupported type") + data_port = self.service_to_process + if isinstance(var, int) or isinstance(var, np.int32): + # First item is number of items (1) - not needed + data_port.recv() + # Data to set + buffer = data_port.recv()[0] + if isinstance(var, int): + setattr(self, var_name, buffer.item()) + else: + setattr(self, var_name, buffer.astype(var.dtype)) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + elif isinstance(var, np.ndarray): + # First item is number of items + num_items = data_port.recv()[0] + var_iter = np.nditer(var, op_flags=['readwrite']) + # Set data one by one + for i in var_iter: + if num_items == 0: + break + num_items -= 1 + i[...] = data_port.recv()[0] + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + elif isinstance(var, csr_matrix): + # First item is number of items + num_items = int(data_port.recv()[0]) + + buffer = np.empty(num_items) + # Set data one by one + for i in range(num_items): + buffer[i] = data_port.recv()[0] + dst, src, _ = find(var) + var = csr_matrix((buffer, (dst, src)), var.shape) + setattr(self, var_name, var) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + elif isinstance(var, str): + # First item is number of items + num_items = int(data_port.recv()[0]) + + s = [] + for i in range(num_items): + # decode string from ascii + s.append(int(data_port.recv()[0])) + + s = bytes(s).decode("ascii") + setattr(self, var_name, s) + self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE) + else: + self.process_to_service.send(MGMT_RESPONSE.ERROR) + raise RuntimeError("Unsupported type") # notify PM that Vars have been changed self.on_var_update() @@ -226,12 +274,10 @@ def run(self): f"command: {cmd} " ) except Exception as inst: - # Inform runtime service about termination self.process_to_service.send(MGMT_RESPONSE.ERROR) - self.join() + self.join() # join cause raise error raise inst - else: - # Handle VarPort requests from RefPorts + elif self._action is not None: self._handle_var_port(self._action) self._channel_actions = [(self.service_to_process, lambda: "cmd")] self.add_ports_for_polling() @@ -503,7 +549,7 @@ def add_ports_for_polling(self): ): for var_port in self.var_ports: for csp_port in var_port.csp_ports: - if isinstance(csp_port, CspRecvPort): + if isinstance(csp_port, RecvPort): def func(fvar_port=var_port): return lambda: fvar_port @@ -591,7 +637,6 @@ def check_for_pause_cmd(self) -> bool: cmd = self.service_to_process.peek() if enum_equal(cmd, MGMT_COMMAND.PAUSE): return True - return False def run_async(self): """ diff --git a/src/lava/magma/core/model/py/ports.py b/src/lava/magma/core/model/py/ports.py index 5a9a61c46..28635803e 100644 --- a/src/lava/magma/core/model/py/ports.py +++ b/src/lava/magma/core/model/py/ports.py @@ -8,8 +8,12 @@ import numpy as np -from lava.magma.compiler.channels.interfaces import AbstractCspPort -from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort +from lava.magma.runtime.message_infrastructure import ( + AbstractTransferPort, + RecvPort, + SendPort +) + from lava.magma.core.model.interfaces import AbstractPortImplementation from lava.magma.core.model.model import AbstractProcessModel from lava.magma.runtime.mgmt_token_enums import enum_to_np, enum_equal @@ -41,7 +45,7 @@ class AbstractPyPort(AbstractPortImplementation): @property @abstractmethod - def csp_ports(self) -> ty.List[AbstractCspPort]: + def csp_ports(self) -> ty.List[AbstractTransferPort]: """ Abstract property to get a list of the corresponding CSP Ports of all connected PyPorts. The CSP Port is the low level interface of the @@ -82,7 +86,7 @@ class AbstractPyIOPort(AbstractPyPort): """ def __init__(self, - csp_ports: ty.List[AbstractCspPort], + csp_ports: ty.List[AbstractTransferPort], process_model: AbstractProcessModel, shape: ty.Tuple[int, ...], d_type: type): @@ -90,7 +94,7 @@ def __init__(self, super().__init__(process_model, shape, d_type) @property - def csp_ports(self) -> ty.List[AbstractCspPort]: + def csp_ports(self) -> ty.List[AbstractTransferPort]: """Property to get the corresponding CSP Ports of all connected PyPorts (csp_ports). The CSP Port is the low level interface of the backend messaging infrastructure which is used to send and receive data. @@ -109,7 +113,7 @@ class AbstractTransformer(ABC): @abstractmethod def transform(self, data: np.ndarray, - csp_port: AbstractCspPort) -> np.ndarray: + csp_port: AbstractTransferPort) -> np.ndarray: """Transforms incoming data in way that is determined by which CSP port the data is received. @@ -132,13 +136,13 @@ class IdentityTransformer(AbstractTransformer): def transform(self, data: np.ndarray, - _: AbstractCspPort) -> np.ndarray: + _: AbstractTransferPort) -> np.ndarray: return data class VirtualPortTransformer(AbstractTransformer): def __init__(self, - csp_ports: ty.Dict[str, AbstractCspPort], + csp_ports: ty.Dict[str, AbstractTransferPort], transform_funcs: ty.Dict[str, ty.List[ft.partial]]): """Transformer that implements the virtual ports on the path to the receiving PyPort. @@ -165,12 +169,11 @@ def __init__(self, def transform(self, data: np.ndarray, - csp_port: AbstractCspPort) -> np.ndarray: + csp_port: AbstractTransferPort) -> np.ndarray: return self._get_transform(csp_port)(data) - def _get_transform(self, - csp_port: AbstractCspPort) -> ty.Callable[[np.ndarray], - np.ndarray]: + def _get_transform(self, csp_port: AbstractTransferPort) \ + -> ty.Callable[[np.ndarray], np.ndarray]: """For a given CSP port, returns a function that applies, in sequence, all the function pointers associated with the incoming virtual ports. @@ -271,7 +274,7 @@ class PyInPort(AbstractPyIOPort): def __init__( self, - csp_ports: ty.List[AbstractCspPort], + csp_ports: ty.List[AbstractTransferPort], process_model: AbstractProcessModel, shape: ty.Tuple[int, ...], d_type: type, @@ -617,21 +620,22 @@ class PyRefPort(AbstractPyPort): def __init__( self, - csp_send_port: ty.Optional[CspSendPort], - csp_recv_port: ty.Optional[CspRecvPort], + csp_send_port: ty.Optional[SendPort], + csp_recv_port: ty.Optional[RecvPort], process_model: AbstractProcessModel, shape: ty.Tuple[int, ...] = tuple(), d_type: type = int, transformer: ty.Optional[ AbstractTransformer] = IdentityTransformer() ): + self._shape = shape self._transformer = transformer self._csp_recv_port = csp_recv_port self._csp_send_port = csp_send_port super().__init__(process_model, shape, d_type) @property - def csp_ports(self) -> ty.List[AbstractCspPort]: + def csp_ports(self) -> ty.List[AbstractTransferPort]: """Property to get the corresponding CSP Ports of all connected PyPorts (csp_ports). The CSP Port is the low level interface of the backend messaging infrastructure which is used to send and receive data. @@ -711,13 +715,16 @@ def read(self) -> np.ndarray: """ if self._csp_send_port and self._csp_recv_port: if not hasattr(self, 'get_header'): - self.get_header = (np.ones(self._csp_send_port.shape) - * VarPortCmd.GET) + # pylint: disable=W0201 + self.get_header = (np.ones(self._csp_send_port.shape, + dtype=self._d_type) + * VarPortCmd.GET.astype(self._d_type)) self._csp_send_port.send(self.get_header) return self._transformer.transform(self._csp_recv_port.recv(), self._csp_recv_port) else: if not hasattr(self, 'get_zeros'): + # pylint: disable=W0201 self.get_zeros = np.zeros(self._shape, self._d_type) return self.get_zeros @@ -732,8 +739,10 @@ def write(self, data: np.ndarray): """ if self._csp_send_port: if not hasattr(self, 'set_header'): - self.set_header = (np.ones(self._csp_send_port.shape) - * VarPortCmd.SET) + # pylint: disable=W0201 + self.set_header = (np.ones(self._csp_send_port.shape, + dtype=data.dtype) + * VarPortCmd.SET.astype(self._d_type)) self._csp_send_port.send(self.set_header) self._csp_send_port.send(data) @@ -849,8 +858,8 @@ class PyVarPort(AbstractPyPort): def __init__(self, var_name: str, - csp_send_port: ty.Optional[CspSendPort], - csp_recv_port: ty.Optional[CspRecvPort], + csp_send_port: ty.Optional[SendPort], + csp_recv_port: ty.Optional[RecvPort], process_model: AbstractProcessModel, shape: ty.Tuple[int, ...] = tuple(), d_type: type = int, @@ -863,7 +872,7 @@ def __init__(self, super().__init__(process_model, shape, d_type) @property - def csp_ports(self) -> ty.List[AbstractCspPort]: + def csp_ports(self) -> ty.List[AbstractTransferPort]: """Property to get the corresponding CSP Ports of all connected PyPorts (csp_ports). The CSP Port is the low level interface of the backend messaging infrastructure which is used to send and receive data. diff --git a/src/lava/magma/core/process/ports/ports.py b/src/lava/magma/core/process/ports/ports.py index be16cf63a..f26d68c4b 100644 --- a/src/lava/magma/core/process/ports/ports.py +++ b/src/lava/magma/core/process/ports/ports.py @@ -689,6 +689,7 @@ def create_implicit_var_port(var: Var) -> ImplicitVarPort: name = str(vp.name) name_suffix = 1 while hasattr(var.process, vp.name): + # pylint: disable=W0201 vp.name = name + "_" + str(name_suffix) name_suffix += 1 setattr(var.process, vp.name, vp) diff --git a/src/lava/magma/core/process/process.py b/src/lava/magma/core/process/process.py index 394b9d0b3..6c8787001 100644 --- a/src/lava/magma/core/process/process.py +++ b/src/lava/magma/core/process/process.py @@ -11,7 +11,8 @@ from lava.magma.core.process.interfaces import \ AbstractProcessMember, IdGeneratorSingleton -from lava.magma.core.process.message_interface_enum import ActorType +from lava.magma.runtime.message_infrastructure.message_interface_enum import \ + ActorType from lava.magma.core.process.ports.ports import \ InPort, OutPort, RefPort, VarPort from lava.magma.core.process.variable import Var diff --git a/src/lava/magma/core/process/variable.py b/src/lava/magma/core/process/variable.py index c2a143ebc..aed95a66d 100644 --- a/src/lava/magma/core/process/variable.py +++ b/src/lava/magma/core/process/variable.py @@ -10,6 +10,7 @@ AbstractProcessMember, IdGeneratorSingleton, ) +from lava.magma.runtime.message_infrastructure import SupportTempChannel class Var(AbstractProcessMember): @@ -138,9 +139,12 @@ def set(self, if self.process.runtime: # encode if var is str if isinstance(value, str): - value = np.array( - list(value.encode("ascii")), dtype=np.int32 - ) + if SupportTempChannel: + value = np.array(value, dtype=str) + else: + value = np.array( + list(value.encode("ascii")), dtype=np.int32 + ) elif isinstance(value, spmatrix): value = value.tocsr() init_dst, init_src, init_val = find(self.init, @@ -154,7 +158,6 @@ def set(self, "elements must stay equal when using" "set on a sparse matrix.") value = val - self.process.runtime.set_var(self.id, value, idx) else: raise ValueError( @@ -171,11 +174,14 @@ def get(self, idx: np.ndarray = None) -> np.ndarray: if self.process and self.process.runtime: buffer = self.process.runtime.get_var(self.id, idx) if isinstance(self.init, str): - # decode if var is string - return bytes(buffer.astype(int).tolist()).decode("ascii") + if SupportTempChannel: + return np.array_str(buffer) + else: + # decode if var is string + return bytes(buffer.astype(int).tolist()). \ + decode("ascii") if isinstance(self.init, csr_matrix): dst, src, _ = find(self.init) - ret = csr_matrix((buffer, (dst, src)), self.init.shape) return ret else: diff --git a/src/lava/magma/runtime/_c_message_infrastructure/.gitignore b/src/lava/magma/runtime/_c_message_infrastructure/.gitignore new file mode 100644 index 000000000..ebdbd2474 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/.gitignore @@ -0,0 +1,46 @@ +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# vscode +.vscode/* + +# Build +build/ + +# Log +log/ + +# Compiled grpc protos +csrc/channel/grpc/grpcchannel* +# Compiled CycloneDDS proto +csrc/channel/dds/protos/cyclone_dds/DDSMetaData* \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/CMakeLists.txt b/src/lava/magma/runtime/_c_message_infrastructure/CMakeLists.txt new file mode 100644 index 000000000..884729c5e --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/CMakeLists.txt @@ -0,0 +1,204 @@ +cmake_minimum_required(VERSION 3.5) +project(message_passing) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_FLAGS_DEBUG "-g -O2") + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build: Debug, Release" FORCE) +endif(NOT CMAKE_BUILD_TYPE) + +option(GRPC_CHANNEL "Use grpc_channel" OFF) +option(DDS_CHANNEL "Message library supports DDS Channel" OFF) +option(FASTDDS_ENABLE "enable FastDDS" OFF) +option(CycloneDDS_ENABLE "enable CycloneDDS" OFF) + +if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +endif() + +if(GRPC_CHANNEL AND CycloneDDS_ENABLE) + message(FATAL_ERROR "Cannot enable GRPC and CycloneDDS together") +endif() + +if(GRPC_CHANNEL) + add_definitions(-DGRPC_CHANNEL) +endif() + +option(ENABLE_MM_PAUSE "Use _mm_pause for sleep." OFF) +if(ENABLE_MM_PAUSE) + add_definitions(-DENABLE_MM_PAUSE) +endif() + +option(PY_WRAPPER "Use pybind11 to wrapper the message infrastructure lib" ON) +if(PY_WRAPPER) + find_package(pybind11 REQUIRED) +endif() + +set(MESSAGE_INFRASTRUCTURE_SRCS + "csrc/core/abstract_actor.cc" + "csrc/core/abstract_port.cc" + "csrc/core/multiprocessing.cc" + "csrc/core/abstract_port_implementation.cc" + "csrc/core/ports.cc" + "csrc/core/channel_factory.cc" + "csrc/core/message_infrastructure_logging.cc" + "csrc/actor/posix_actor.cc" + "csrc/channel/shmem/shm.cc" + "csrc/channel/shmem/shmem_channel.cc" + "csrc/channel/shmem/shmem_port.cc" + "csrc/channel/socket/socket.cc" + "csrc/channel/socket/socket_channel.cc" + "csrc/channel/socket/socket_port.cc") + +if(GRPC_CHANNEL) + set(GRPC_FETCHCONTENT 1) + set(GRPC_TAG v1.49.1) + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/grpc_common.cmake) + + set(grpc_path "${CMAKE_CURRENT_SOURCE_DIR}/csrc/channel/grpc/") + get_filename_component(grpc_proto "${grpc_path}/protos/grpcchannel.proto" ABSOLUTE) + get_filename_component(grpc_proto_path "${grpc_proto}" PATH) + + set(grpc_proto_srcs "${grpc_path}/grpcchannel.pb.cc") + set(grpc_proto_hdrs "${grpc_path}/grpcchannel.pb.h") + set(grpc_srcs "${grpc_path}/grpcchannel.grpc.pb.cc") + set(grpc_hdrs "${grpc_path}/grpcchannel.grpc.pb.h") + add_custom_command( + OUTPUT "${grpc_proto_srcs}" "${grpc_proto_hdrs}" "${grpc_srcs}" "${grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${grpc_path}" + --cpp_out "${grpc_path}" + -I "${grpc_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${grpc_proto}" + DEPENDS "${grpc_proto}") + + set(GRPC_CHANNEL_SRCS + "csrc/channel/grpc/grpc.cc" + "csrc/channel/grpc/grpc_port.cc" + "csrc/channel/grpc/grpc_channel.cc" + ${grpc_proto_srcs} + ${grpc_srcs}) +endif() + + +if(DDS_CHANNEL) + set(COMMON_DDS_DESTINATION "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/install") + message("DDS Library destination: ${COMMON_DDS_DESTINATION}") + set(COMMON_DDS_INC "${COMMON_DDS_DESTINATION}/include") + set(DDS_CHANNEL_SRCS + "csrc/channel/dds/dds.cc" + "csrc/channel/dds/dds_channel.cc") + + if(FASTDDS_ENABLE AND CycloneDDS_ENABLE) + message(FATAL_ERROR "Cannot enable Backend FASTDDS and CycloneDDS together") + endif() + + if(CycloneDDS_ENABLE) + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/cyclonedds.cmake) + set(CYCLONE_DDS_INC "${COMMON_DDS_DESTINATION}/include/ddscxx") + add_custom_target(CycloneDDS_metdata ALL + DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/csrc/channel/dds/protos/cyclone_dds/DDSMetaData.cpp") + add_dependencies(CycloneDDS_metdata cyclonedds-cxx) + add_custom_command(OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/csrc/channel/dds/protos/cyclone_dds/DDSMetaData.cpp" + COMMAND "${COMMON_DDS_DESTINATION}/bin/idlc" -l cxx ${CMAKE_CURRENT_SOURCE_DIR}/csrc/channel/dds/protos/DDSMetaData.idl + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/csrc/channel/dds/protos/cyclone_dds) + + add_definitions(-DCycloneDDS_ENABLE) + set(CycloneDDS_SRC + ${DDS_CHANNEL_SRCS} + "csrc/channel/dds/cyclone_dds.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/csrc/channel/dds/protos/cyclone_dds/DDSMetaData.cpp") + endif() + + if(FASTDDS_ENABLE) + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/fastdds.cmake) + add_definitions(-DFASTDDS_ENABLE) + set(FASTDDS_SRC + ${DDS_CHANNEL_SRCS} + "csrc/channel/dds/fast_dds.cc" + "csrc/channel/dds/protos/fast_dds/DDSMetaData.cc" + "csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.cc") + endif() + + if(FASTDDS_ENABLE OR CycloneDDS_ENABLE) + add_definitions(-DDDS_CHANNEL) + else() + message(FATAL_ERROR "Please enable DDS backend, (FASTDDS_ENABLE or CycloneDDS_ENABLE)") + endif() +else() + set(FASTDDS_ENABLE OFF) + set(CycloneDDS_ENABLE OFF) +endif() + +add_library(message_infrastructure SHARED + ${MESSAGE_INFRASTRUCTURE_SRCS} + $<$:${CycloneDDS_SRC}> + $<$:${FASTDDS_SRC}> + $<$:${GRPC_CHANNEL_SRCS}>) + +set(MSG_LOG_LEVEL err CACHE STRING "Default vaule: err, error log only.") +set(MSG_LOG_FILE_ENABLE 0 CACHE STRING "Default value: 0, print onto console only.") + +target_compile_definitions(message_infrastructure PUBLIC + $<$:MSG_LOG_LEVEL_ALL> + $<$:MSG_LOG_LEVEL_WARN> + $<$:MSG_LOG_LEVEL_DUMP> + $<$:MSG_LOG_LEVEL_INFO> + $<$:MSG_LOG_LEVEL_ERRO> + $<$:MSG_LOG_LEVEL_ALL>) + +target_compile_definitions(message_infrastructure PUBLIC + $<$:MSG_LOG_FILE_ENABLE>) + +target_include_directories(message_infrastructure PUBLIC + ${PROJECT_SOURCE_DIR}/csrc + $ + $) + +target_link_libraries(message_infrastructure + rt) +if(GRPC_CHANNEL) + target_link_libraries(message_infrastructure + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) +endif() + +if(CycloneDDS_ENABLE) + add_dependencies(message_infrastructure cyclonedds-cxx CycloneDDS_metdata) + target_link_libraries(message_infrastructure + ${COMMON_DDS_DESTINATION}/lib/libddsc.so + ${COMMON_DDS_DESTINATION}/lib/libddscxx.so) +endif() + +if(FASTDDS_ENABLE) + add_dependencies(message_infrastructure foonathan_memory fastcdr fastrtps) + target_link_libraries(message_infrastructure + ${COMMON_DDS_DESTINATION}/lib/libfastcdr.so + ${COMMON_DDS_DESTINATION}/lib/libfastrtps.so) +endif() + +if(PY_WRAPPER) + set(PY_WRAPPER_SRCS + "csrc/message_infrastructure_py_wrapper.cc" + "csrc/channel_proxy.cc" + "csrc/port_proxy.cc") + + pybind11_add_module(MessageInfrastructurePywrapper ${PY_WRAPPER_SRCS}) + target_include_directories(MessageInfrastructurePywrapper PUBLIC + ${NUMPY_INCLUDE_DIRS}) + target_link_libraries(MessageInfrastructurePywrapper PRIVATE message_infrastructure) +endif() + +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + message("debug mode and enable cpp unit test") + enable_testing() + add_subdirectory(test) +else() + message("not debug mode and disable cpp unit test") +endif() + +add_subdirectory(examples/c_pingpong) diff --git a/src/lava/magma/runtime/_c_message_infrastructure/README.md b/src/lava/magma/runtime/_c_message_infrastructure/README.md new file mode 100644 index 000000000..83d9f8c77 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/README.md @@ -0,0 +1,83 @@ +# Message Infrastructure Library CPP Implementation for LAVA +version: v0.2.1 +## Introduction +The message infrastructure library is for LAVA to transfer data. The library provides several method to do communication for IPC on single host or across multiple hosts. + +## Build +Assume you are in `/` folder now. +### 1. Set cmake args by env variables to build the message infrastructure library according to your requirements. +```bash +$ export CMAKE_ARGS="..." +``` +#### (1) If you want to use PythonWrapper of the lib, this step could be just ignored as this is the default setting. +#### (2) If you do not need to use PythonWrapper of the lib and just use the lib for CPP, run the command: +```bash +$ export CMAKE_ARGS="-DPY_WRAPPER=OFF" +``` +#### (3) If you want to use GRPC channel, run the command: + +```bash +$ export CMAKE_ARGS="-DGRPC_CHANNEL=ON" +``` + +Note : +- If your env is using http/https proxy, please unable the proxy to use grpc channel.
+You could use the commands in your ternimal, + ```bash + $ unset http_proxy + $ unset https_proxy + ``` +- When you use grpc channel at main and sub processes together, pls refer to [this link](https://github.com/grpc/grpc/blob/master/doc/fork_support.md) to set env. +- There are conflict of `LOCKABLE` definition at CycloneDDS and gRPC, so reject enabling GRPC_CHANNEL and CycloneDDS_ENABLE together. + +#### (4) If you want to enable DDS channel, run the command: +```bash +$ export CMAKE_ARGS="-DDDS_CHANNEL=ON -D_ENABLE=ON" +# [DDS_BACKEND: FASTDDS, CycloneDDS ..., only support FASTDDS now] +# Before build FastDDS, need to install dependences by below command. +# sudo apt-get install libasio-dev libtinyxml2-dev +``` + +#### (5) Build with cpp unit tests + +```bash +$ export CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Debug" +``` + +### 2. Compile library code +- Run the command to build message infrastructure library. + ```bash + $ python3 prebuild.py + ``` +- If you have select to use PythonWrapper, GRPC channel, DDS channel or CPP unit tests, the source code will be compiled together with the message infrastructure library code. +### 3. Add PYTHONPATH +- Add PYTHONPATH into terminal environment. + ```bash +$ export PYTHONPATH=src/:$PYTHONPATH + ``` +## Run Python test +- For example, run the python test for channel usage + ```bash + $ python3 tests/lava/magma/runtime/message_infrastructure/test_channel.py + ``` + - Run all tests + ```bash + # when enable grpc channel, need to add following env: + # export GRPC_ENABLE_FORK_SUPPORT=true + # export GRPC_POLL_STRATEGY=poll + $ pytest tests/lava/magma/runtime/message_infrastructure/ + ``` + +## Run CPP test +- Run all the CPP test for msg lib + ```bash + $ build/test/test_messaging_infrastructure + ``` + +## Install by poetry +Also users could choose to use poetry to enbale the whole environment. +```bash +$ export CMAKE_ARGS="..." +$ poetry install +$ source .venv/bin/activate +``` diff --git a/src/lava/magma/runtime/_c_message_infrastructure/cmake/cyclonedds.cmake b/src/lava/magma/runtime/_c_message_infrastructure/cmake/cyclonedds.cmake new file mode 100644 index 000000000..8face693e --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/cmake/cyclonedds.cmake @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.14) + +include(ExternalProject) + +ExternalProject_Add( + cyclonedds + GIT_REPOSITORY https://github.com/eclipse-cyclonedds/cyclonedds.git + GIT_TAG 0.10.2 + SOURCE_DIR ${CMAKE_BINARY_DIR}/cyclonedds + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${COMMON_DDS_DESTINATION} +) + +ExternalProject_Add( + cyclonedds-cxx + GIT_REPOSITORY https://github.com/eclipse-cyclonedds/cyclonedds-cxx.git + GIT_TAG 0.10.2 + SOURCE_DIR ${CMAKE_BINARY_DIR}/cyclonedds-cxx + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${COMMON_DDS_DESTINATION} +) + +add_dependencies(cyclonedds-cxx cyclonedds) diff --git a/src/lava/magma/runtime/_c_message_infrastructure/cmake/fastdds.cmake b/src/lava/magma/runtime/_c_message_infrastructure/cmake/fastdds.cmake new file mode 100644 index 000000000..aed976fd6 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/cmake/fastdds.cmake @@ -0,0 +1,23 @@ +cmake_minimum_required(VERSION 3.14) + +include(ExternalProject) +ExternalProject_Add( + foonathan_memory + GIT_REPOSITORY https://github.com/eProsima/foonathan_memory_vendor.git + SOURCE_DIR ${CMAKE_BINARY_DIR}/foonathan_memory + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${COMMON_DDS_DESTINATION} +) + +ExternalProject_Add( + fastcdr + GIT_REPOSITORY https://github.com/eProsima/Fast-CDR.git + SOURCE_DIR ${CMAKE_BINARY_DIR}/fastcdr + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${COMMON_DDS_DESTINATION} +) + +ExternalProject_Add( + fastrtps + GIT_REPOSITORY https://github.com/eProsima/Fast-DDS.git + SOURCE_DIR ${CMAKE_BINARY_DIR}/fastrtps + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${COMMON_DDS_DESTINATION} -DCMAKE_PREFIX_PATH=${COMMON_DDS_DESTINATION} +) diff --git a/src/lava/magma/runtime/_c_message_infrastructure/cmake/grpc_common.cmake b/src/lava/magma/runtime/_c_message_infrastructure/cmake/grpc_common.cmake new file mode 100644 index 000000000..c1d4e1d31 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/cmake/grpc_common.cmake @@ -0,0 +1,123 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cmake build file for C++ route_guide example. +# Assumes protobuf and gRPC have been installed using cmake. +# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build +# that automatically builds all the dependencies before building route_guide. + +cmake_minimum_required(VERSION 3.5.1) + +set (CMAKE_CXX_STANDARD 17) + +if(MSVC) + add_definitions(-D_WIN32_WINNT=0x600) +endif() + +find_package(Threads REQUIRED) + +if(GRPC_AS_SUBMODULE) + # One way to build a projects that uses gRPC is to just include the + # entire gRPC project tree via "add_subdirectory". + # This approach is very simple to use, but the are some potential + # disadvantages: + # * it includes gRPC's CMakeLists.txt directly into your build script + # without and that can make gRPC's internal setting interfere with your + # own build. + # * depending on what's installed on your system, the contents of submodules + # in gRPC's third_party/* might need to be available (and there might be + # additional prerequisites required to build them). Consider using + # the gRPC_*_PROVIDER options to fine-tune the expected behavior. + # + # A more robust approach to add dependency on gRPC is using + # cmake's ExternalProject_Add (see cmake_externalproject/CMakeLists.txt). + + # Include the gRPC's cmake build (normally grpc source code would live + # in a git submodule called "third_party/grpc", but this example lives in + # the same repository as gRPC sources, so we just look a few directories up) + add_subdirectory(../../.. ${CMAKE_CURRENT_BINARY_DIR}/grpc EXCLUDE_FROM_ALL) + message(STATUS "Using gRPC via add_subdirectory.") + + # After using add_subdirectory, we can now use the grpc targets directly from + # this build. + set(_PROTOBUF_LIBPROTOBUF libprotobuf) + set(_REFLECTION grpc++_reflection) + if(CMAKE_CROSSCOMPILING) + find_program(_PROTOBUF_PROTOC protoc) + else() + set(_PROTOBUF_PROTOC $) + endif() + set(_GRPC_GRPCPP grpc++) + if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + endif() +elseif(GRPC_FETCHCONTENT) + # Another way is to use CMake's FetchContent module to clone gRPC at + # configure time. This makes gRPC's source code available to your project, + # similar to a git submodule. + message(STATUS "Using gRPC via add_subdirectory (FetchContent).") + include(FetchContent) + FetchContent_Declare( + grpc + GIT_REPOSITORY https://github.com/grpc/grpc.git + # when using gRPC, you will actually set this to an existing tag, such as + # v1.25.0, v1.26.0 etc.. + # For the purpose of testing, we override the tag used to the commit + # that's currently under test. + GIT_TAG ${GRPC_TAG}) + FetchContent_MakeAvailable(grpc) + + # Since FetchContent uses add_subdirectory under the hood, we can use + # the grpc targets directly from this build. + set(_PROTOBUF_LIBPROTOBUF libprotobuf) + set(_REFLECTION grpc++_reflection) + set(_PROTOBUF_PROTOC $) + set(_GRPC_GRPCPP grpc++) + if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + endif() +else() + # This branch assumes that gRPC and all its dependencies are already installed + # on this system, so they can be located by find_package(). + + # Find Protobuf installation + # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. + set(protobuf_MODULE_COMPATIBLE TRUE) + find_package(Protobuf CONFIG REQUIRED) + message(STATUS "Using protobuf ${Protobuf_VERSION}") + + set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) + set(_REFLECTION gRPC::grpc++_reflection) + if(CMAKE_CROSSCOMPILING) + find_program(_PROTOBUF_PROTOC protoc) + else() + set(_PROTOBUF_PROTOC $) + endif() + + # Find gRPC installation + # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. + find_package(gRPC CONFIG REQUIRED) + message(STATUS "Using gRPC ${gRPC_VERSION}") + + set(_GRPC_GRPCPP gRPC::grpc++) + if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + endif() +endif() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/actor/posix_actor.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/actor/posix_actor.cc new file mode 100644 index 000000000..f6db32879 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/actor/posix_actor.cc @@ -0,0 +1,86 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include + +#include +#include +#include + +namespace message_infrastructure { + +int CheckSemaphore(sem_t *sem) { + int sem_val; + sem_getvalue(sem, &sem_val); + if (sem_val < 0) { + LAVA_LOG_ERR("Get the negtive sem value: %d\n", sem_val); + return -1; + } + if (sem_val == 1) { + LAVA_LOG_ERR("There is a semaphere not used\n"); + return 1; + } + + return 0; +} + +int PosixActor::GetPid() { + return pid_; +} + +int PosixActor::Wait() { + int status; + int options = 0; + int ret = waitpid(pid_, &status, options); + + if (ret < 0) { + LAVA_LOG_ERR("Process %d waitpid error\n", pid_); + return -1; + } + + LAVA_DEBUG(LOG_ACTOR, + "current actor status: %d\n", + static_cast(GetStatus())); + // Check the status + return 0; +} +int PosixActor::ForceStop() { + int status; + kill(pid_, SIGTERM); + wait(&status); + if (WIFSIGNALED(status)) { + if (WTERMSIG(status) == SIGTERM) { + LAVA_LOG(LOG_MP, "The Actor child was ended with SIGTERM\n"); + } else { + LAVA_LOG(LOG_MP, "The Actor child was ended with signal %d\n", status); + } + } + SetStatus(ActorStatus::StatusTerminated); + return 0; +} + +ProcessType PosixActor::Create() { + pid_t pid = fork(); + if (pid > 0) { + LAVA_LOG(LOG_MP, "Parent Process, create child process %d\n", pid); + pid_ = pid; + return ProcessType::ParentProcess; + } + + if (pid == 0) { + LogClear(); + LAVA_LOG(LOG_MP, "Child, new process %d\n", getpid()); + pid_ = getpid(); + Run(); + this->~PosixActor(); + exit(0); + } + LAVA_LOG_ERR("Cannot allocate new pid for the process\n"); + return ProcessType::ErrorProcess; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/actor/posix_actor.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/actor/posix_actor.h new file mode 100644 index 000000000..0688f4073 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/actor/posix_actor.h @@ -0,0 +1,26 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef ACTOR_POSIX_ACTOR_H_ +#define ACTOR_POSIX_ACTOR_H_ + +#include + +namespace message_infrastructure { + +class PosixActor final : public AbstractActor { + public: + using AbstractActor::AbstractActor; + ~PosixActor() override {} + int GetPid(); + int Wait(); + int ForceStop(); + ProcessType Create(); +}; + +using PosixActorPtr = PosixActor *; + +} // namespace message_infrastructure + +#endif // ACTOR_POSIX_ACTOR_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/cyclone_dds.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/cyclone_dds.cc new file mode 100644 index 000000000..9b55f4cf4 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/cyclone_dds.cc @@ -0,0 +1,246 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include + +#include +#include + +namespace message_infrastructure { + +using namespace org::eclipse::cyclonedds; // NOLINT + +void CycloneDDSPubListener::on_offered_incompatible_qos( + dds::pub::DataWriter& writer, + const dds::core::status::OfferedIncompatibleQosStatus& status) { + LAVA_LOG_WARN(LOG_DDS, + "incompatiable qos found, count: %d\n", + status.total_count()); +} + +void CycloneDDSPubListener::on_publication_matched( + dds::pub::DataWriter& writer, + const dds::core::status::PublicationMatchedStatus &info) { + matched_.store(info.current_count()); + if (info.current_count_change() == 1) { + LAVA_LOG(LOG_DDS, + "CycloneDDS DataReader %d matched.\n", + matched_.load()); + } else if (info.current_count_change() == -1) { + LAVA_LOG(LOG_DDS, + "CycloneDDS DataReader unmatched. left:%d\n", + matched_.load()); + } else { + LAVA_LOG_ERR("CycloneDDS Publistener MatchedStatus error\n"); + } +} + +DDSInitErrorType CycloneDDSPublisher::Init() { + LAVA_LOG(LOG_DDS, "publisher init\n"); + LAVA_DEBUG(LOG_DDS, + "Init CycloneDDS Publisher Successfully, topic name: %s\n", + topic_name_.c_str()); + dds_metadata_ = std::make_shared(); + if (dds_transfer_type_ != DDSTransportType::DDSUDPv4) { + LAVA_LOG_WARN(LOG_DDS, "Unsupport Transfer type and will use UDP\n"); + } + participant_ = dds::domain::DomainParticipant(domain::default_id()); + topic_ = dds::topic::Topic(participant_, + topic_name_); + publisher_ = dds::pub::Publisher(participant_); + listener_ = std::make_shared(); + dds::pub::qos::DataWriterQos wqos = publisher_.default_datawriter_qos(); + wqos << dds::core::policy::History::KeepLast(max_samples_) + << dds::core::policy::Reliability::Reliable(dds::core::Duration + ::from_secs(HEARTBEAT_PERIOD_SECONDS)) + << dds::core::policy::Durability::Volatile(); + writer_ = dds::pub::DataWriter( + publisher_, + topic_, + wqos, + listener_.get(), + dds::core::status::StatusMask::all()); + stop_ = false; + return DDSInitErrorType::DDSNOERR; +} + +bool CycloneDDSPublisher::Publish(DataPtr data) { + LAVA_DEBUG(LOG_DDS, + "CycloneDDS publisher start publishing topic name = %s, matched:%d\n", + topic_name_.c_str(), listener_->matched_.load()); + LAVA_DEBUG(LOG_DDS, + "writer_ matched: %d\n", + writer_.publication_matched_status().current_count()); + while (writer_.publication_matched_status().current_count() == 0) { + helper::Sleep(); + } + LAVA_DEBUG(LOG_DDS, "CycloneDDS publisher find matched reader\n"); + MetaData* metadata = reinterpret_cast(data.get()); + dds_metadata_->nd(metadata->nd); + dds_metadata_->type(metadata->type); + dds_metadata_->elsize(metadata->elsize); + dds_metadata_->total_size(metadata->total_size); + + memcpy(&dds_metadata_->dims()[0], metadata->dims, sizeof(metadata->dims)); + memcpy(&dds_metadata_->strides()[0], + metadata->strides, + sizeof(metadata->strides)); + size_t nbytes = metadata->elsize * metadata->total_size; + dds_metadata_->mdata(std::vector( + reinterpret_cast(metadata->mdata), + reinterpret_cast(metadata->mdata) + nbytes)); + LAVA_DEBUG(LOG_DDS, "CycloneDDS publisher copied\n"); + writer_.write(*dds_metadata_.get()); + LAVA_DEBUG(LOG_DDS, "datawriter send the data\n"); + return true; +} + +void CycloneDDSPublisher::Stop() { + LAVA_LOG(LOG_DDS, \ + "Stop CycloneDDS Publisher, topic_name%s, waiting unmatched...\n", + topic_name_.c_str()); + if (stop_) { + return; + } + while (listener_ != nullptr && listener_->matched_.load() > 0) { + helper::Sleep(); + } + if (writer_ != dds::core::null) { + writer_ = dds::core::null; + } + if (publisher_ != dds::core::null) { + publisher_ = dds::core::null; + } + if (topic_ != dds::core::null) { + topic_ = dds::core::null; + } + if (participant_ != dds::core::null) { + participant_ = dds::core::null; + } + stop_ = true; +} +CycloneDDSPublisher::~CycloneDDSPublisher() { + if (!stop_) { + Stop(); + } +} + +void CycloneDDSSubListener::on_subscription_matched( + dds::sub::DataReader &reader, + const dds::core::status::SubscriptionMatchedStatus &info) { + matched_.store(info.current_count()); + if (info.current_count_change() == 1) { + LAVA_LOG(LOG_DDS, + "CycloneDDS DataWriter %d matched.\n", + matched_.load()); + } else if (info.current_count_change() == -1) { + LAVA_LOG(LOG_DDS, + "CycloneDDS DataWriter unmatched. left:%d\n", + matched_.load()); + } else { + LAVA_LOG_ERR("CycloneDDS Sublistener MatchedStatus error\n"); + } +} +DDSInitErrorType CycloneDDSSubscriber::Init() { + LAVA_LOG(LOG_DDS, "subscriber init\n"); + LAVA_DEBUG(LOG_DDS, + "Init CycloneDDS Subscriber, topic name: %s\n", + topic_name_.c_str()); + if (dds_transfer_type_ != DDSTransportType::DDSUDPv4) { + LAVA_LOG_WARN(LOG_DDS, "Unsupport Transfer type and will use UDP\n"); + } + participant_ = dds::domain::DomainParticipant(domain::default_id()); + topic_ = dds::topic::Topic(participant_, + topic_name_); + subscriber_ = dds::sub::Subscriber(participant_); + listener_ = std::make_shared(); + dds::sub::qos::DataReaderQos rqos = subscriber_.default_datareader_qos(); + rqos << dds::core::policy::History::KeepLast(max_samples_) + << dds::core::policy::Reliability::Reliable(dds::core::Duration + ::from_secs(HEARTBEAT_PERIOD_SECONDS)) + << dds::core::policy::Durability::Volatile(); + dds::core::policy::History history; + + reader_ = dds::sub::DataReader( + subscriber_, + topic_, + rqos, + listener_.get(), + dds::core::status::StatusMask::all()); + selector_ = std::make_shared::Selector>(reader_); + selector_->max_samples(1); + stop_ = false; + return DDSInitErrorType::DDSNOERR; +} + +MetaDataPtr CycloneDDSSubscriber::Recv(bool keep) { + LAVA_DEBUG(LOG_DDS, + "CycloneDDS topic name= %s recving...\n", + topic_name_.c_str()); + dds::sub::LoanedSamples samples; + if (keep) { + while ((samples = selector_->read()).length() <= 0) { // Flawfinder: ignore + helper::Sleep(); + } + } else { + while ((samples = selector_->take()).length() <= 0) { // Flawfinder: ignore + helper::Sleep(); + } + } + + if (samples.length() != 1) { + LAVA_LOG_FATAL("Cylones recv %d samples\n", samples.length()); + } + auto iter = samples.begin(); + if (iter->info().valid()) { + MetaDataPtr metadata = std::make_shared(); + auto dds_metadata = iter->data(); + metadata->nd = dds_metadata.nd(); + metadata->type = dds_metadata.type(); + metadata->elsize = dds_metadata.elsize(); + metadata->total_size = dds_metadata.total_size(); + memcpy(metadata->dims, dds_metadata.dims().data(), sizeof(metadata->dims)); + memcpy(metadata->strides, + dds_metadata.strides().data(), + sizeof(metadata->strides)); + int nbytes = metadata->elsize * metadata->total_size; + void *ptr = malloc(nbytes); + memcpy(ptr, dds_metadata.mdata().data(), nbytes); + metadata->mdata = ptr; + LAVA_DEBUG(LOG_DDS, "Data Recieved\n"); + return metadata; + } else { + LAVA_LOG_ERR("Time out and no data received\n"); + } + return nullptr; +} + +bool CycloneDDSSubscriber::Probe() { + return (selector_->read()).length() > 0; // Flawfinder: ignore +} +void CycloneDDSSubscriber::Stop() { + if (stop_) + return; + LAVA_DEBUG(LOG_DDS, + "CycloneDDSSubscriber topic name = %s Stop and release...\n", + topic_name_.c_str()); + if (listener_ != nullptr && reader_ != dds::core::null) { + reader_.~DataReader(); + reader_ = dds::core::null; + } + if (participant_ != dds::core::null) participant_ = dds::core::null; + if (subscriber_ != dds::core::null) subscriber_ = dds::core::null; + if (topic_ != dds::core::null) topic_ = dds::core::null; + stop_ = true; +} + +CycloneDDSSubscriber::~CycloneDDSSubscriber() { + if (!stop_) { + Stop(); + } +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/cyclone_dds.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/cyclone_dds.h new file mode 100644 index 000000000..56e8bbb1f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/cyclone_dds.h @@ -0,0 +1,108 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_DDS_CYCLONE_DDS_H_ +#define CHANNEL_DDS_CYCLONE_DDS_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +class CycloneDDSPubListener final : public + dds::pub::NoOpDataWriterListener{ + public: + CycloneDDSPubListener() : matched_(0) {} + void on_offered_incompatible_qos( + dds::pub::DataWriter& writer, + const dds::core::status::OfferedIncompatibleQosStatus& status) override; + void on_publication_matched( + dds::pub::DataWriter &writer, + const dds::core::status::PublicationMatchedStatus &info) override; + ~CycloneDDSPubListener() override {} + std::atomic_uint32_t matched_; +}; + +using CycloneDDSPubListenerPtr = std::shared_ptr; + +class CycloneDDSPublisher final : public DDSPublisher { + public: + CycloneDDSPublisher(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_sample) : + stop_(true), + topic_name_(topic_name), + dds_transfer_type_(dds_transfer_type), + max_samples_(max_sample) {} + ~CycloneDDSPublisher() override; + DDSInitErrorType Init(); + bool Publish(DataPtr metadata); + void Stop(); // Can Init again + + private: + CycloneDDSPubListenerPtr listener_ = nullptr; + std::shared_ptr dds_metadata_ = nullptr; + dds::domain::DomainParticipant participant_ = dds::core::null; + dds::topic::Topic topic_ = dds::core::null; + dds::pub::Publisher publisher_ = dds::core::null; + dds::pub::DataWriter writer_ = dds::core::null; + + std::string topic_name_; + DDSTransportType dds_transfer_type_; + size_t max_samples_; + + bool stop_; +}; + +class CycloneDDSSubListener final : public + dds::sub::NoOpDataReaderListener{ + public: + CycloneDDSSubListener() : matched_(0) {} + ~CycloneDDSSubListener() {} + void on_subscription_matched( + dds::sub::DataReader &reader, + const dds::core::status::SubscriptionMatchedStatus &info) override; + std::atomic_uint32_t matched_; +}; + +using CycloneDDSSubListenerPtr = std::shared_ptr; + +class CycloneDDSSubscriber final : public DDSSubscriber { + public: + CycloneDDSSubscriber(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_sample) : + stop_(true), + topic_name_(topic_name), + dds_transfer_type_(dds_transfer_type), + max_samples_(max_sample) {} + ~CycloneDDSSubscriber() override; + DDSInitErrorType Init(); + void Stop(); + MetaDataPtr Recv(bool keep); + bool Probe(); + + private: + CycloneDDSSubListenerPtr listener_ = nullptr; + dds::domain::DomainParticipant participant_ = dds::core::null; + dds::topic::Topic topic_ = dds::core::null; + dds::sub::Subscriber subscriber_ = dds::core::null; + dds::sub::DataReader reader_ = dds::core::null; + std::shared_ptr::Selector> + selector_ = nullptr; + + std::string topic_name_; + DDSTransportType dds_transfer_type_; + size_t max_samples_; + bool stop_; +}; + +} // namespace message_infrastructure + +#endif // CHANNEL_DDS_CYCLONE_DDS_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds.cc new file mode 100644 index 000000000..d1b63fe9c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds.cc @@ -0,0 +1,97 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#if defined(FASTDDS_ENABLE) +#include +#endif +#if defined(CycloneDDS_ENABLE) +#include +#endif +#include +#include +#include +#include // NOLINT + +namespace message_infrastructure { +DDSPtr DDSManager::AllocDDS(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend, + const size_t &max_samples) { + std::lock_guard lg(dds_lock_); + + if (dds_topics_.find(topic_name) != dds_topics_.end()) { + LAVA_LOG_ERR("The topic %s has already been used\n", topic_name.c_str()); + return nullptr; + } + dds_topics_.insert(topic_name); + DDSPtr dds = std::make_shared(topic_name, + dds_transfer_type, + dds_backend, + max_samples); + ddss_.push_back(dds); + return dds; +} + +void DDSManager::DeleteAllDDS() { + ddss_.clear(); + dds_topics_.clear(); +} + +DDSManager::~DDSManager() { + DeleteAllDDS(); +} + +void DDS::CreateFastDDSBackend(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_samples) { +#if defined(FASTDDS_ENABLE) + LAVA_DEBUG(LOG_DDS, "DDS::CreateFastDDSBackend\n"); + dds_publisher_ = std::make_shared(topic_name, + dds_transfer_type, + max_samples); + dds_subscriber_ = std::make_shared(topic_name, + dds_transfer_type, + max_samples); +#else + LAVA_LOG_FATAL("FastDDS is not enable, exit!\n"); +#endif +} + +void DDS::CreateCycloneDDSBackend(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_samples) { +#if defined(CycloneDDS_ENABLE) + dds_publisher_ = std::make_shared(topic_name, + dds_transfer_type, + max_samples); + dds_subscriber_ = std::make_shared(topic_name, + dds_transfer_type, + max_samples); +#else + LAVA_LOG_FATAL("CycloneDDS is not enable, exit!\n"); +#endif +} + +DDS::DDS(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend, + const size_t &max_samples) { + if (dds_backend == DDSBackendType::FASTDDSBackend) { + CreateFastDDSBackend(topic_name, dds_transfer_type, max_samples); + } else if (dds_backend == DDSBackendType::CycloneDDSBackend) { + CreateCycloneDDSBackend(topic_name, dds_transfer_type, max_samples); + } else { + LAVA_LOG_ERR("Not support DDSBackendType provided, %d\n", + static_cast(dds_backend)); + } +} + +DDSManager DDSManager::dds_manager_; + +DDSManager& GetDDSManagerSingleton() { + DDSManager &dds_manager = DDSManager::dds_manager_; + return dds_manager; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds.h new file mode 100644 index 000000000..314af882f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds.h @@ -0,0 +1,93 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_DDS_DDS_H_ +#define CHANNEL_DDS_DDS_H_ + +#include +#include +#include +#include +#include +#include +#include // NOLINT + +namespace message_infrastructure { +class DDSPublisher { + public: + virtual DDSInitErrorType Init() = 0; + virtual bool Publish(DataPtr data) = 0; + virtual void Stop() = 0; + virtual ~DDSPublisher() {} +}; + +// DDSPublisher object needs to be transfered to DDSPort. +// Also need to be handled in DDS class. +// Use std::shared_ptr. +using DDSPublisherPtr = std::shared_ptr; + +class DDSSubscriber { + public: + virtual DDSInitErrorType Init() = 0; + virtual MetaDataPtr Recv(bool keep) = 0; + virtual bool Probe() = 0; + virtual void Stop() = 0; + virtual ~DDSSubscriber() {} +}; + +// DDSSubscriber object needs to be transfered to DDSPort. +// Also need to be handled in DDS class. +// Use std::shared_ptr. +using DDSSubscriberPtr = std::shared_ptr; + +class DDS { + public: + DDS(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend, + const size_t &max_samples); + DDSPublisherPtr dds_publisher_ = nullptr; + DDSSubscriberPtr dds_subscriber_ = nullptr; + + private: + void CreateFastDDSBackend(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_samples); + void CreateCycloneDDSBackend(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_samples); +}; + +// DDS object needs to be transfered to DDSPort. +// Also need to be handled in DDSManager. +// Use std::shared_ptr. +using DDSPtr = std::shared_ptr; + +class DDSManager { + public: + DDSManager(const DDSManager&) = delete; + DDSManager(DDSManager&&) = delete; + DDSManager& operator=(const DDSManager&) = delete; + DDSManager& operator=(DDSManager&&) = delete; + DDSPtr AllocDDS(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend, + const size_t &max_samples); + void DeleteAllDDS(); + friend DDSManager &GetDDSManagerSingleton(); + + private: + DDSManager() = default; + ~DDSManager(); + std::mutex dds_lock_; + std::vector ddss_; + std::unordered_set dds_topics_; + static DDSManager dds_manager_; +}; + +DDSManager& GetDDSManagerSingleton(); + +} // namespace message_infrastructure + +#endif // CHANNEL_DDS_DDS_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_channel.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_channel.cc new file mode 100644 index 000000000..8856f20d2 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_channel.cc @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include + +namespace message_infrastructure { + +DDSChannel::DDSChannel(const std::string &src_name, + const std::string &dst_name, + const std::string &topic_name, + const size_t &size, + const size_t &nbytes, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend) { + LAVA_DEBUG(LOG_DDS, "Creating DDSChannel...\n"); + + dds_ = GetDDSManagerSingleton().AllocDDS( + topic_name, + dds_transfer_type, + dds_backend, + size); + send_port_ = std::make_shared(src_name, size, nbytes, dds_); + recv_port_ = std::make_shared(dst_name, size, nbytes, dds_); +} + +AbstractSendPortPtr DDSChannel::GetSendPort() { + return send_port_; +} + +AbstractRecvPortPtr DDSChannel::GetRecvPort() { + return recv_port_; +} + +std::shared_ptr GetDefaultDDSChannel(const size_t &nbytes, + const size_t &size, + const std::string &src_name, + const std::string &dst_name) { + DDSBackendType BackendType = DDSBackendType::FASTDDSBackend; + #if defined(CycloneDDS_ENABLE) + BackendType = DDSBackendType::CycloneDDSBackend; + #endif + return std::make_shared( + src_name, + dst_name, + "dds_topic_" + std::to_string(std::rand()), + size, + nbytes, + DDSTransportType::DDSUDPv4, + BackendType); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_channel.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_channel.h new file mode 100644 index 000000000..9680d14b2 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_channel.h @@ -0,0 +1,46 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_DDS_DDS_CHANNEL_H_ +#define CHANNEL_DDS_DDS_CHANNEL_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace message_infrastructure { + +class DDSChannel : public AbstractChannel { + public: + DDSChannel() = delete; + ~DDSChannel() override {} + DDSChannel(const std::string &src_name, + const std::string &dst_name, + const std::string &topic_name, + const size_t &size, + const size_t &nbytes, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend); + AbstractSendPortPtr GetSendPort(); + AbstractRecvPortPtr GetRecvPort(); + + private: + DDSPtr dds_ = nullptr; + DDSSendPortPtr send_port_ = nullptr; + DDSRecvPortPtr recv_port_ = nullptr; +}; + +std::shared_ptr GetDefaultDDSChannel(const size_t &nbytes, + const size_t &size, + const std::string &src_name, + const std::string &dst_name); + +} // namespace message_infrastructure + +#endif // CHANNEL_DDS_DDS_CHANNEL_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_port.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_port.h new file mode 100644 index 000000000..c1b7b679f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/dds_port.h @@ -0,0 +1,86 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_DDS_DDS_PORT_H_ +#define CHANNEL_DDS_DDS_PORT_H_ + +#include +#include +#include +namespace message_infrastructure { + +class DDSSendPort final : public AbstractSendPort { + public: + DDSSendPort() = delete; + DDSSendPort(const std::string &name, + const size_t &size, + const size_t &nbytes, + DDSPtr dds) :AbstractSendPort(name, size, nbytes), + publisher_(dds->dds_publisher_) {} + ~DDSSendPort() = default; + void Start() { + auto flag = publisher_->Init(); + if (static_cast(flag)) { + LAVA_LOG_FATAL("Publisher Init return error, %d\n", + static_cast(flag)); + } + } + void Send(DataPtr data) { + while (!publisher_->Publish(data)) { + helper::Sleep(); + } + } + void Join() { + publisher_->Stop(); + } + bool Probe() { + return false; + } + + private: + DDSPublisherPtr publisher_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using DDSSendPortPtr = std::shared_ptr; + +class DDSRecvPort final : public AbstractRecvPort { + public: + DDSRecvPort() = delete; + DDSRecvPort(const std::string &name, + const size_t &size, + const size_t &nbytes, + DDSPtr dds) :AbstractRecvPort(name, size, nbytes), + subscriber_(dds->dds_subscriber_) {} + ~DDSRecvPort() override {} + void Start() { + auto flag = subscriber_->Init(); + if (static_cast(flag)) { + LAVA_LOG_FATAL("Subscriber Init return error, %d\n", + static_cast(flag)); + } + } + MetaDataPtr Recv() { + return subscriber_->Recv(false); + } + void Join() { + subscriber_->Stop(); + } + MetaDataPtr Peek() { + return subscriber_->Recv(true); + } + bool Probe() { + return subscriber_->Probe(); + } + + private: + DDSSubscriberPtr subscriber_; +}; + +using DDSRecvPortPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CHANNEL_DDS_DDS_PORT_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/fast_dds.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/fast_dds.cc new file mode 100644 index 000000000..216a91fc7 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/fast_dds.cc @@ -0,0 +1,358 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace message_infrastructure { + +using namespace eprosima::fastdds::dds; // NOLINT +using namespace eprosima::fastdds::rtps; // NOLINT +using namespace eprosima::fastrtps::rtps; // NOLINT + +FastDDSPublisher::~FastDDSPublisher() { + LAVA_DEBUG(LOG_DDS, "FastDDS Publisher releasing...\n"); + if (!stop_) { + LAVA_LOG_WARN(LOG_DDS, "Code should Stop Publisher\n"); + Stop(); + } + LAVA_DEBUG(LOG_DDS, "FastDDS Publisher released\n"); +} + +DDSInitErrorType FastDDSPublisher::Init() { + dds_metadata_ = std::make_shared(); + InitParticipant(); + if (participant_ == nullptr) + return DDSInitErrorType::DDSParticipantError; + type_.register_type(participant_); + publisher_ = participant_->create_publisher(PUBLISHER_QOS_DEFAULT); + if (publisher_ == nullptr) + return DDSInitErrorType::DDSPublisherError; + + topic_ = participant_->create_topic(topic_name_, + DDS_DATATYPE_NAME, + TOPIC_QOS_DEFAULT); + if (topic_ == nullptr) + return DDSInitErrorType::DDSTopicError; + + listener_ = std::make_shared(); + InitDataWriter(); + if (writer_ == nullptr) + return DDSInitErrorType::DDSDataWriterError; + + LAVA_DEBUG(LOG_DDS, "Init Fast DDS Publisher Successfully, topic name: %s\n", + topic_name_.c_str()); + stop_ = false; + return DDSInitErrorType::DDSNOERR; +} + +void FastDDSPublisher::InitDataWriter() { + DataWriterQos wqos; + wqos.history().kind = KEEP_ALL_HISTORY_QOS; + wqos.history().depth = max_samples_; + wqos.resource_limits().max_samples = max_samples_; + wqos.resource_limits().allocated_samples = max_samples_ / 2; + wqos.reliable_writer_qos().times + .heartbeatPeriod.seconds = HEARTBEAT_PERIOD_SECONDS; + wqos.reliable_writer_qos().times + .heartbeatPeriod.nanosec = HEARTBEAT_PERIOD_NANOSEC; + wqos.reliability().kind = RELIABLE_RELIABILITY_QOS; + wqos.publish_mode().kind = ASYNCHRONOUS_PUBLISH_MODE; + wqos.endpoint().history_memory_policy = PREALLOCATED_WITH_REALLOC_MEMORY_MODE; + writer_ = publisher_->create_datawriter(topic_, wqos, listener_.get()); +} + +void FastDDSPublisher::InitParticipant() { + DomainParticipantQos pqos; + pqos.transport().use_builtin_transports = false; + pqos.name("Participant pub" + topic_name_); + + auto transport_descriptor = GetTransportDescriptor(dds_transfer_type_); + if (nullptr == transport_descriptor) { + LAVA_LOG_FATAL("Create Transport Fault, exit\n"); + } + pqos.transport().user_transports.push_back(transport_descriptor); + + if (dds_transfer_type_ == DDSTransportType::DDSTCPv4) { + Locator_t initial_peer_locator; + initial_peer_locator.kind = LOCATOR_KIND_TCPv4; + IPLocator::setIPv4(initial_peer_locator, TCPv4_IP); + initial_peer_locator.port = TCP_PORT; + pqos.wire_protocol().builtin.initialPeersList + .push_back(initial_peer_locator); + } + + participant_ = DomainParticipantFactory::get_instance() + ->create_participant(0, pqos); +} + +bool FastDDSPublisher::Publish(DataPtr data) { + LAVA_DEBUG(LOG_DDS, + "FastDDS Publish topic name = %s\n", + topic_name_.c_str()); + MetaData* metadata = reinterpret_cast(data.get()); + if (listener_->matched_ > 0) { + LAVA_DEBUG(LOG_DDS, "FastDDS publisher start publishing...\n"); + dds_metadata_->nd(metadata->nd); + dds_metadata_->type(metadata->type); + dds_metadata_->elsize(metadata->elsize); + dds_metadata_->total_size(metadata->total_size); + memcpy(&dds_metadata_->dims()[0], metadata->dims, sizeof(metadata->dims)); + memcpy(&dds_metadata_->strides()[0], + metadata->strides, + sizeof(metadata->strides)); + size_t nbytes = metadata->elsize * metadata->total_size; + dds_metadata_->mdata(std::vector( + reinterpret_cast(metadata->mdata), + reinterpret_cast(metadata->mdata) + nbytes)); + LAVA_DEBUG(LOG_DDS, "FastDDS publisher copied\n"); + + if (writer_->write(dds_metadata_.get()) != ReturnCode_t::RETCODE_OK) { + LAVA_LOG_WARN(LOG_DDS, "Publisher write return not OK, Why work?\n"); + } else { + LAVA_DEBUG(LOG_DDS, "Publish a data\n"); + } + return true; + } + return false; +} + +void FastDDSPublisher::Stop() { + LAVA_LOG(LOG_DDS, "Stop FastDDS Publisher, waiting unmatched...\n"); + while (listener_ != nullptr && listener_->matched_ > 0) { + helper::Sleep(); + } + if (writer_ != nullptr) { + publisher_->delete_datawriter(writer_); + } + if (publisher_ != nullptr) { + participant_->delete_publisher(publisher_); + } + if (topic_ != nullptr) { + topic_->close(); + participant_->delete_topic(topic_); + } + if (participant_ != nullptr) { + DomainParticipantFactory::get_instance()->delete_participant(participant_); + } + stop_ = true; +} + +void FastDDSPubListener::on_publication_matched( + eprosima::fastdds::dds::DataWriter*, + const eprosima::fastdds::dds::PublicationMatchedStatus& info) { + if (info.current_count_change == 1) { + matched_++; + LAVA_DEBUG(LOG_DDS, "FastDDS DataReader %d matched.\n", matched_); + } else if (info.current_count_change == -1) { + matched_--; + LAVA_DEBUG(LOG_DDS, "FastDDS DataReader unmatched, remain:%d\n", matched_); + } else { + LAVA_LOG_ERR("FastDDS Publistener status error\n"); + } +} + +void FastDDSSubListener::on_subscription_matched( + DataReader*, + const SubscriptionMatchedStatus& info) { + if (info.current_count_change == 1) { + matched_++; + LAVA_DEBUG(LOG_DDS, "FastDDS DataWriter %d matched.\n", matched_); + } else if (info.current_count_change == -1) { + matched_--; + LAVA_DEBUG(LOG_DDS, "FastDDS DataWriter unmatched, remain:%d\n", matched_); + } else { + LAVA_LOG_ERR("Subscriber number is not matched\n"); + } +} + +FastDDSSubscriber::~FastDDSSubscriber() { + LAVA_DEBUG(LOG_DDS, "FastDDS Subscriber Releasing...\n"); + if (!stop_) { + LAVA_LOG_WARN(LOG_DDS, "Code should Stop Subscriber\n"); + Stop(); + } + LAVA_DEBUG(LOG_DDS, "FastDDS Subscriber Released...\n"); +} + +void FastDDSSubscriber::InitParticipant() { + DomainParticipantQos pqos; + pqos.wire_protocol().builtin.discovery_config.discoveryProtocol + = DiscoveryProtocol_t::SIMPLE; + pqos.wire_protocol().builtin.discovery_config. + use_SIMPLE_EndpointDiscoveryProtocol = true; + pqos.wire_protocol().builtin.discovery_config.m_simpleEDP. + use_PublicationReaderANDSubscriptionWriter = true; + pqos.wire_protocol().builtin.discovery_config.m_simpleEDP. + use_PublicationWriterANDSubscriptionReader = true; + pqos.wire_protocol().builtin.discovery_config.leaseDuration + = eprosima::fastrtps::c_TimeInfinite; + pqos.transport().use_builtin_transports = false; + pqos.name("Participant sub" + topic_name_); + + auto transport_descriptor = GetTransportDescriptor(dds_transfer_type_); + if (nullptr == transport_descriptor) { + LAVA_LOG_FATAL("Create Transport Fault, exit\n"); + } + pqos.transport().user_transports.push_back(transport_descriptor); + + participant_ = DomainParticipantFactory::get_instance() + ->create_participant(0, pqos); +} + +void FastDDSSubscriber::InitDataReader() { + DataReaderQos rqos; + rqos.history().kind = KEEP_ALL_HISTORY_QOS; + rqos.history().depth = max_samples_; + rqos.resource_limits().max_samples = max_samples_; + rqos.resource_limits().allocated_samples = max_samples_ / 2; + rqos.reliability().kind = RELIABLE_RELIABILITY_QOS; + rqos.durability().kind = TRANSIENT_LOCAL_DURABILITY_QOS; + rqos.endpoint().history_memory_policy = PREALLOCATED_WITH_REALLOC_MEMORY_MODE; + reader_ = subscriber_->create_datareader(topic_, rqos, listener_.get()); +} + +DDSInitErrorType FastDDSSubscriber::Init() { + InitParticipant(); + if (participant_ == nullptr) + return DDSInitErrorType::DDSParticipantError; + + type_.register_type(participant_); + subscriber_ = participant_->create_subscriber(SUBSCRIBER_QOS_DEFAULT); + if (subscriber_ == nullptr) + return DDSInitErrorType::DDSSubscriberError; + + topic_ = participant_->create_topic(topic_name_, + DDS_DATATYPE_NAME, + TOPIC_QOS_DEFAULT); + if (topic_ == nullptr) + return DDSInitErrorType::DDSTopicError; + + listener_ = std::make_shared(); + InitDataReader(); + if (reader_ == nullptr) + return DDSInitErrorType::DDSDataReaderError; + + LAVA_DEBUG(LOG_DDS, "Init FastDDS Subscriber Successfully, topic name: %s\n", + topic_name_.c_str()); + stop_ = false; + return DDSInitErrorType::DDSNOERR; +} + +MetaDataPtr FastDDSSubscriber::Recv(bool keep) { + LAVA_DEBUG(LOG_DDS, "FastDDS Recv topic name = %s\n", topic_name_.c_str()); + FASTDDS_CONST_SEQUENCE(MDataSeq, ddsmetadata::msg::DDSMetaData); + MDataSeq mdata_seq; + SampleInfoSeq infos; + if (keep) { + LAVA_DEBUG(LOG_DDS, "Keep the data recieved\n"); + while (ReturnCode_t::RETCODE_OK != + reader_->read(mdata_seq, infos, 1)) { // Flawfinder: ignore + helper::Sleep(); + } + } else { + LAVA_DEBUG(LOG_DDS, "Take the data recieved\n"); + + while (ReturnCode_t::RETCODE_OK != + reader_->take(mdata_seq, infos, 1)) { + helper::Sleep(); + } + } + + LAVA_DEBUG(LOG_DDS, "Return the data recieved\n"); + LAVA_DEBUG(LOG_DDS, "INFO length: %d\n", infos.length()); + if (infos[0].valid_data) { + const ddsmetadata::msg::DDSMetaData& dds_metadata = mdata_seq[0]; + MetaDataPtr metadata = std::make_shared(); + metadata->nd = dds_metadata.nd(); + metadata->type = dds_metadata.type(); + metadata->elsize = dds_metadata.elsize(); + metadata->total_size = dds_metadata.total_size(); + memcpy(metadata->dims, dds_metadata.dims().data(), sizeof(metadata->dims)); + memcpy(metadata->strides, + dds_metadata.strides().data(), + sizeof(metadata->strides)); + int nbytes = metadata->elsize * metadata->total_size; + void *ptr = std::calloc(nbytes, 1); + if (ptr == nullptr) { + LAVA_LOG_ERR("alloc failed, errno: %d\n", errno); + } + memcpy(ptr, dds_metadata.mdata().data(), nbytes); + metadata->mdata = ptr; + reader_->return_loan(mdata_seq, infos); + LAVA_DEBUG(LOG_DDS, "Data Recieved\n"); + return metadata; + } else { + LAVA_LOG_WARN(LOG_DDS, "Remote writer die\n"); + } + + LAVA_LOG_ERR("time out and no data received\n"); + return nullptr; +} + +bool FastDDSSubscriber::Probe() { + FASTDDS_CONST_SEQUENCE(MDataSeq, ddsmetadata::msg::DDSMetaData); + MDataSeq mdata_seq; + SampleInfoSeq infos; + bool res = false; + if (ReturnCode_t::RETCODE_OK == + reader_->read(mdata_seq, infos, 1)) { // Flawfinder: ignore + reader_->return_loan(mdata_seq, infos); + res = true; + } + return res; +} + +void FastDDSSubscriber::Stop() { + LAVA_DEBUG(LOG_DDS, "Subscriber Stop and release\n"); + if (reader_ != nullptr) + subscriber_->delete_datareader(reader_); + if (topic_ != nullptr) + participant_->delete_topic(topic_); + if (subscriber_ != nullptr) + participant_->delete_subscriber(subscriber_); + if (participant_ != nullptr) + DomainParticipantFactory::get_instance()->delete_participant(participant_); + stop_ = true; +} + +std::shared_ptr +GetTransportDescriptor(const DDSTransportType &dds_type) { + if (dds_type == DDSTransportType::DDSSHM) { + LAVA_DEBUG(LOG_DDS, "Shared Memory Transport Descriptor\n"); + auto transport = std::make_shared(); + transport->segment_size(SHM_SEGMENT_SIZE); + return transport; + } else if (dds_type == DDSTransportType::DDSTCPv4) { + LAVA_DEBUG(LOG_DDS, "TCPv4 Transport Descriptor\n"); + auto transport = std::make_shared(); + transport->set_WAN_address(TCPv4_IP); + transport->add_listener_port(TCP_PORT); + transport->interfaceWhiteList.push_back(TCPv4_IP); // loopback + return transport; + } else if (dds_type == DDSTransportType::DDSUDPv4) { + LAVA_DEBUG(LOG_DDS, "UDPv4 Transport Descriptor\n"); + auto transport = std::make_shared(); + transport->m_output_udp_socket = UDP_OUT_PORT; + transport->non_blocking_send = NON_BLOCKING_SEND; + return transport; + } else { + LAVA_LOG_ERR("TransportType %d has not supported\n", + static_cast(dds_type)); + } + return nullptr; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/fast_dds.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/fast_dds.h new file mode 100644 index 000000000..5bd34aadd --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/fast_dds.h @@ -0,0 +1,134 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_DDS_FAST_DDS_H_ +#define CHANNEL_DDS_FAST_DDS_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +class FastDDSPubListener final : public + eprosima::fastdds::dds::DataWriterListener { + public: + FastDDSPubListener() : matched_(0) {} + ~FastDDSPubListener() override {} + void on_publication_matched( + eprosima::fastdds::dds::DataWriter* writer, + const eprosima::fastdds::dds::PublicationMatchedStatus& info) override; + + int matched_; +}; + +// FastDDSPubListener object needs to be transfered to DDSPort. +// Also need to be handled in DDS class. +// Use std::shared_ptr. +using FastDDSPubListenerPtr = std::shared_ptr; + +class FastDDSPublisher final : public DDSPublisher { + public: + FastDDSPublisher(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_samples) : + type_(new ddsmetadata::msg::DDSMetaDataPubSubType()), + stop_(true), + topic_name_(topic_name), + dds_transfer_type_(dds_transfer_type), + max_samples_(max_samples) {} + ~FastDDSPublisher() override; + DDSInitErrorType Init(); + bool Publish(DataPtr data); + void Stop(); // Can Init again + + private: + void InitDataWriter(); + void InitParticipant(); + + FastDDSPubListenerPtr listener_ = nullptr; + std::shared_ptr dds_metadata_; + eprosima::fastdds::dds::DomainParticipant* participant_ = nullptr; + eprosima::fastdds::dds::Publisher* publisher_ = nullptr; + eprosima::fastdds::dds::Topic* topic_ = nullptr; + eprosima::fastdds::dds::DataWriter* writer_ = nullptr; + eprosima::fastdds::dds::TypeSupport type_; + + std::string topic_name_; + DDSTransportType dds_transfer_type_; + size_t max_samples_; + + bool stop_; +}; + +class FastDDSSubListener final : public + eprosima::fastdds::dds::DataReaderListener { + public: + FastDDSSubListener() : matched_(0) {} + ~FastDDSSubListener() override {} + void on_data_available( + eprosima::fastdds::dds::DataReader* reader) override {}; + void on_subscription_matched( + eprosima::fastdds::dds::DataReader* reader, + const eprosima::fastdds::dds::SubscriptionMatchedStatus& info) override; + int matched_; +}; + +// FastDDSSubListener object needs to be transfered to DDSPort. +// Also need to be handled in DDS class. +// Use std::shared_ptr. +using FastDDSSubListenerPtr = std::shared_ptr; + +class FastDDSSubscriber final : public DDSSubscriber { + public: + FastDDSSubscriber(const std::string &topic_name, + const DDSTransportType &dds_transfer_type, + const size_t &max_samples) : + type_(new ddsmetadata::msg::DDSMetaDataPubSubType()), + stop_(true), + topic_name_(topic_name), + dds_transfer_type_(dds_transfer_type), + max_samples_(max_samples) {} + ~FastDDSSubscriber() override; + DDSInitErrorType Init(); + void Stop(); + MetaDataPtr Recv(bool keep); + bool Probe(); + + private: + void InitParticipant(); + void InitDataReader(); + FastDDSSubListenerPtr listener_ = nullptr; + eprosima::fastdds::dds::DomainParticipant* participant_ = nullptr; + eprosima::fastdds::dds::Subscriber* subscriber_ = nullptr; + eprosima::fastdds::dds::Topic* topic_ = nullptr; + eprosima::fastdds::dds::DataReader* reader_ = nullptr; + eprosima::fastdds::dds::TypeSupport type_; + + std::string topic_name_; + DDSTransportType dds_transfer_type_; + size_t max_samples_; + bool stop_; +}; + +std::shared_ptr +GetTransportDescriptor(const DDSTransportType &dds_type); +} // namespace message_infrastructure + +#endif // CHANNEL_DDS_FAST_DDS_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/DDSMetaData.idl b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/DDSMetaData.idl new file mode 100644 index 000000000..ceac12244 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/DDSMetaData.idl @@ -0,0 +1,20 @@ +module ddsmetadata { + module msg { + typedef int64 int64__5[5]; + struct DDSMetaData { + int64 nd; + + int64 type; + + int64 elsize; + + int64 total_size; + + int64__5 dims; + + int64__5 strides; + + sequence mdata; + }; + }; +}; diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/cyclone_dds/.gitkeep b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/cyclone_dds/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaData.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaData.cc new file mode 100644 index 000000000..25c64efd4 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaData.cc @@ -0,0 +1,500 @@ +// Copyright 2016 Proyectos y Sistemas de Mantenimiento SL (eProsima). +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*! + * @file DDSMetaData.cpp + * This source file contains the definition of the described types in the IDL file. + * + * This file was generated by the tool gen. + */ + +#ifdef _WIN32 +// Remove linker warning LNK4221 on Visual Studio +namespace { +char dummy; +} // namespace +#endif // _WIN32 + +#include "DDSMetaData.h" +#include + +#include +using namespace eprosima::fastcdr::exception; + +#include + + +ddsmetadata::msg::DDSMetaData::DDSMetaData() +{ + // m_nd com.eprosima.idl.parser.typecode.PrimitiveTypeCode@442675e1 + m_nd = 0; + // m_type com.eprosima.idl.parser.typecode.PrimitiveTypeCode@6166e06f + m_type = 0; + // m_elsize com.eprosima.idl.parser.typecode.PrimitiveTypeCode@49e202ad + m_elsize = 0; + // m_total_size com.eprosima.idl.parser.typecode.PrimitiveTypeCode@1c72da34 + m_total_size = 0; + // m_dims com.eprosima.idl.parser.typecode.AliasTypeCode@6b0c2d26 + memset(&m_dims, 0, (5) * 8); + // m_strides com.eprosima.idl.parser.typecode.AliasTypeCode@6b0c2d26 + memset(&m_strides, 0, (5) * 8); + // m_mdata com.eprosima.idl.parser.typecode.SequenceTypeCode@3d3fcdb0 + + +} + +ddsmetadata::msg::DDSMetaData::~DDSMetaData() +{ + + + + + + + +} + +ddsmetadata::msg::DDSMetaData::DDSMetaData( + const DDSMetaData& x) +{ + m_nd = x.m_nd; + m_type = x.m_type; + m_elsize = x.m_elsize; + m_total_size = x.m_total_size; + // cppcheck-suppress useInitializationList + m_dims = x.m_dims; + // cppcheck-suppress useInitializationList + m_strides = x.m_strides; + // cppcheck-suppress useInitializationList + m_mdata = x.m_mdata; +} + +ddsmetadata::msg::DDSMetaData::DDSMetaData( + DDSMetaData&& x) +{ + m_nd = x.m_nd; + m_type = x.m_type; + m_elsize = x.m_elsize; + m_total_size = x.m_total_size; + // cppcheck-suppress useInitializationList + m_dims = std::move(x.m_dims); + m_strides = std::move(x.m_strides); + m_mdata = std::move(x.m_mdata); +} + +ddsmetadata::msg::DDSMetaData& ddsmetadata::msg::DDSMetaData::operator =( + const DDSMetaData& x) +{ + + m_nd = x.m_nd; + m_type = x.m_type; + m_elsize = x.m_elsize; + m_total_size = x.m_total_size; + m_dims = x.m_dims; + m_strides = x.m_strides; + m_mdata = x.m_mdata; + + return *this; +} + +ddsmetadata::msg::DDSMetaData& ddsmetadata::msg::DDSMetaData::operator =( + DDSMetaData&& x) +{ + + m_nd = x.m_nd; + m_type = x.m_type; + m_elsize = x.m_elsize; + m_total_size = x.m_total_size; + m_dims = std::move(x.m_dims); + m_strides = std::move(x.m_strides); + m_mdata = std::move(x.m_mdata); + + return *this; +} + +bool ddsmetadata::msg::DDSMetaData::operator ==( + const DDSMetaData& x) const +{ + + return (m_nd == x.m_nd && m_type == x.m_type && m_elsize == x.m_elsize && m_total_size == x.m_total_size && m_dims == x.m_dims && m_strides == x.m_strides && m_mdata == x.m_mdata); +} + +bool ddsmetadata::msg::DDSMetaData::operator !=( + const DDSMetaData& x) const +{ + return !(*this == x); +} + +size_t ddsmetadata::msg::DDSMetaData::getMaxCdrSerializedSize( + size_t current_alignment) +{ + size_t initial_alignment = current_alignment; + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += ((5) * 8) + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += ((5) * 8) + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 4 + eprosima::fastcdr::Cdr::alignment(current_alignment, 4); + + current_alignment += (100 * 1) + eprosima::fastcdr::Cdr::alignment(current_alignment, 1); + + + + + return current_alignment - initial_alignment; +} + +size_t ddsmetadata::msg::DDSMetaData::getCdrSerializedSize( + const ddsmetadata::msg::DDSMetaData& data, + size_t current_alignment) +{ + (void)data; + size_t initial_alignment = current_alignment; + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + current_alignment += 8 + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + + + if ((5) > 0) + { + current_alignment += ((5) * 8) + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + } + + if ((5) > 0) + { + current_alignment += ((5) * 8) + eprosima::fastcdr::Cdr::alignment(current_alignment, 8); + } + + current_alignment += 4 + eprosima::fastcdr::Cdr::alignment(current_alignment, 4); + + if (data.mdata().size() > 0) + { + current_alignment += (data.mdata().size() * 1) + eprosima::fastcdr::Cdr::alignment(current_alignment, 1); + } + + + + + return current_alignment - initial_alignment; +} + +void ddsmetadata::msg::DDSMetaData::serialize( + eprosima::fastcdr::Cdr& scdr) const +{ + + scdr << m_nd; + scdr << m_type; + scdr << m_elsize; + scdr << m_total_size; + scdr << m_dims; + + scdr << m_strides; + + scdr << m_mdata; + +} + +void ddsmetadata::msg::DDSMetaData::deserialize( + eprosima::fastcdr::Cdr& dcdr) +{ + + dcdr >> m_nd; + dcdr >> m_type; + dcdr >> m_elsize; + dcdr >> m_total_size; + dcdr >> m_dims; + + dcdr >> m_strides; + + dcdr >> m_mdata; +} + +/*! + * @brief This function sets a value in member nd + * @param _nd New value for member nd + */ +void ddsmetadata::msg::DDSMetaData::nd( + int64_t _nd) +{ + m_nd = _nd; +} + +/*! + * @brief This function returns the value of member nd + * @return Value of member nd + */ +int64_t ddsmetadata::msg::DDSMetaData::nd() const +{ + return m_nd; +} + +/*! + * @brief This function returns a reference to member nd + * @return Reference to member nd + */ +int64_t& ddsmetadata::msg::DDSMetaData::nd() +{ + return m_nd; +} + +/*! + * @brief This function sets a value in member type + * @param _type New value for member type + */ +void ddsmetadata::msg::DDSMetaData::type( + int64_t _type) +{ + m_type = _type; +} + +/*! + * @brief This function returns the value of member type + * @return Value of member type + */ +int64_t ddsmetadata::msg::DDSMetaData::type() const +{ + return m_type; +} + +/*! + * @brief This function returns a reference to member type + * @return Reference to member type + */ +int64_t& ddsmetadata::msg::DDSMetaData::type() +{ + return m_type; +} + +/*! + * @brief This function sets a value in member elsize + * @param _elsize New value for member elsize + */ +void ddsmetadata::msg::DDSMetaData::elsize( + int64_t _elsize) +{ + m_elsize = _elsize; +} + +/*! + * @brief This function returns the value of member elsize + * @return Value of member elsize + */ +int64_t ddsmetadata::msg::DDSMetaData::elsize() const +{ + return m_elsize; +} + +/*! + * @brief This function returns a reference to member elsize + * @return Reference to member elsize + */ +int64_t& ddsmetadata::msg::DDSMetaData::elsize() +{ + return m_elsize; +} + +/*! + * @brief This function sets a value in member total_size + * @param _total_size New value for member total_size + */ +void ddsmetadata::msg::DDSMetaData::total_size( + int64_t _total_size) +{ + m_total_size = _total_size; +} + +/*! + * @brief This function returns the value of member total_size + * @return Value of member total_size + */ +int64_t ddsmetadata::msg::DDSMetaData::total_size() const +{ + return m_total_size; +} + +/*! + * @brief This function returns a reference to member total_size + * @return Reference to member total_size + */ +int64_t& ddsmetadata::msg::DDSMetaData::total_size() +{ + return m_total_size; +} + +/*! + * @brief This function copies the value in member dims + * @param _dims New value to be copied in member dims + */ +void ddsmetadata::msg::DDSMetaData::dims( + const ddsmetadata::msg::int64__5& _dims) +{ + m_dims = _dims; +} + +/*! + * @brief This function moves the value in member dims + * @param _dims New value to be moved in member dims + */ +void ddsmetadata::msg::DDSMetaData::dims( + ddsmetadata::msg::int64__5&& _dims) +{ + m_dims = std::move(_dims); +} + +/*! + * @brief This function returns a constant reference to member dims + * @return Constant reference to member dims + */ +const ddsmetadata::msg::int64__5& ddsmetadata::msg::DDSMetaData::dims() const +{ + return m_dims; +} + +/*! + * @brief This function returns a reference to member dims + * @return Reference to member dims + */ +ddsmetadata::msg::int64__5& ddsmetadata::msg::DDSMetaData::dims() +{ + return m_dims; +} +/*! + * @brief This function copies the value in member strides + * @param _strides New value to be copied in member strides + */ +void ddsmetadata::msg::DDSMetaData::strides( + const ddsmetadata::msg::int64__5& _strides) +{ + m_strides = _strides; +} + +/*! + * @brief This function moves the value in member strides + * @param _strides New value to be moved in member strides + */ +void ddsmetadata::msg::DDSMetaData::strides( + ddsmetadata::msg::int64__5&& _strides) +{ + m_strides = std::move(_strides); +} + +/*! + * @brief This function returns a constant reference to member strides + * @return Constant reference to member strides + */ +const ddsmetadata::msg::int64__5& ddsmetadata::msg::DDSMetaData::strides() const +{ + return m_strides; +} + +/*! + * @brief This function returns a reference to member strides + * @return Reference to member strides + */ +ddsmetadata::msg::int64__5& ddsmetadata::msg::DDSMetaData::strides() +{ + return m_strides; +} +/*! + * @brief This function copies the value in member mdata + * @param _mdata New value to be copied in member mdata + */ +void ddsmetadata::msg::DDSMetaData::mdata( + const std::vector& _mdata) +{ + m_mdata = _mdata; +} + +/*! + * @brief This function moves the value in member mdata + * @param _mdata New value to be moved in member mdata + */ +void ddsmetadata::msg::DDSMetaData::mdata( + std::vector&& _mdata) +{ + m_mdata = std::move(_mdata); +} + +/*! + * @brief This function returns a constant reference to member mdata + * @return Constant reference to member mdata + */ +const std::vector& ddsmetadata::msg::DDSMetaData::mdata() const +{ + return m_mdata; +} + +/*! + * @brief This function returns a reference to member mdata + * @return Reference to member mdata + */ +std::vector& ddsmetadata::msg::DDSMetaData::mdata() +{ + return m_mdata; +} + +size_t ddsmetadata::msg::DDSMetaData::getKeyMaxCdrSerializedSize( + size_t current_alignment) +{ + size_t current_align = current_alignment; + + + + + + + + + + + return current_align; +} + +bool ddsmetadata::msg::DDSMetaData::isKeyDefined() +{ + return false; +} + +void ddsmetadata::msg::DDSMetaData::serializeKey( + eprosima::fastcdr::Cdr& scdr) const +{ + (void) scdr; + +} + + diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaData.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaData.h new file mode 100644 index 000000000..28381e629 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaData.h @@ -0,0 +1,351 @@ +// Copyright 2016 Proyectos y Sistemas de Mantenimiento SL (eProsima). +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*! + * @file DDSMetaData.h + * This header file contains the declaration of the described types in the IDL file. + * + * This file was generated by the tool gen. + */ + +#ifndef _FAST_DDS_GENERATED_DDSMETADATA_MSG_DDSMETADATA_H_ +#define _FAST_DDS_GENERATED_DDSMETADATA_MSG_DDSMETADATA_H_ + + +#include + +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) +#if defined(EPROSIMA_USER_DLL_EXPORT) +#define eProsima_user_DllExport __declspec( dllexport ) +#else +#define eProsima_user_DllExport +#endif // EPROSIMA_USER_DLL_EXPORT +#else +#define eProsima_user_DllExport +#endif // _WIN32 + +#if defined(_WIN32) +#if defined(EPROSIMA_USER_DLL_EXPORT) +#if defined(DDSMetaData_SOURCE) +#define DDSMetaData_DllAPI __declspec( dllexport ) +#else +#define DDSMetaData_DllAPI __declspec( dllimport ) +#endif // DDSMetaData_SOURCE +#else +#define DDSMetaData_DllAPI +#endif // EPROSIMA_USER_DLL_EXPORT +#else +#define DDSMetaData_DllAPI +#endif // _WIN32 + +namespace eprosima { +namespace fastcdr { +class Cdr; +} // namespace fastcdr +} // namespace eprosima + + +namespace ddsmetadata { + namespace msg { + typedef std::array int64__5; + /*! + * @brief This class represents the structure DDSMetaData defined by the user in the IDL file. + * @ingroup DDSMETADATA + */ + class DDSMetaData + { + public: + + /*! + * @brief Default constructor. + */ + eProsima_user_DllExport DDSMetaData(); + + /*! + * @brief Default destructor. + */ + eProsima_user_DllExport ~DDSMetaData(); + + /*! + * @brief Copy constructor. + * @param x Reference to the object ddsmetadata::msg::DDSMetaData that will be copied. + */ + eProsima_user_DllExport DDSMetaData( + const DDSMetaData& x); + + /*! + * @brief Move constructor. + * @param x Reference to the object ddsmetadata::msg::DDSMetaData that will be copied. + */ + eProsima_user_DllExport DDSMetaData( + DDSMetaData&& x); + + /*! + * @brief Copy assignment. + * @param x Reference to the object ddsmetadata::msg::DDSMetaData that will be copied. + */ + eProsima_user_DllExport DDSMetaData& operator =( + const DDSMetaData& x); + + /*! + * @brief Move assignment. + * @param x Reference to the object ddsmetadata::msg::DDSMetaData that will be copied. + */ + eProsima_user_DllExport DDSMetaData& operator =( + DDSMetaData&& x); + + /*! + * @brief Comparison operator. + * @param x ddsmetadata::msg::DDSMetaData object to compare. + */ + eProsima_user_DllExport bool operator ==( + const DDSMetaData& x) const; + + /*! + * @brief Comparison operator. + * @param x ddsmetadata::msg::DDSMetaData object to compare. + */ + eProsima_user_DllExport bool operator !=( + const DDSMetaData& x) const; + + /*! + * @brief This function sets a value in member nd + * @param _nd New value for member nd + */ + eProsima_user_DllExport void nd( + int64_t _nd); + + /*! + * @brief This function returns the value of member nd + * @return Value of member nd + */ + eProsima_user_DllExport int64_t nd() const; + + /*! + * @brief This function returns a reference to member nd + * @return Reference to member nd + */ + eProsima_user_DllExport int64_t& nd(); + + /*! + * @brief This function sets a value in member type + * @param _type New value for member type + */ + eProsima_user_DllExport void type( + int64_t _type); + + /*! + * @brief This function returns the value of member type + * @return Value of member type + */ + eProsima_user_DllExport int64_t type() const; + + /*! + * @brief This function returns a reference to member type + * @return Reference to member type + */ + eProsima_user_DllExport int64_t& type(); + + /*! + * @brief This function sets a value in member elsize + * @param _elsize New value for member elsize + */ + eProsima_user_DllExport void elsize( + int64_t _elsize); + + /*! + * @brief This function returns the value of member elsize + * @return Value of member elsize + */ + eProsima_user_DllExport int64_t elsize() const; + + /*! + * @brief This function returns a reference to member elsize + * @return Reference to member elsize + */ + eProsima_user_DllExport int64_t& elsize(); + + /*! + * @brief This function sets a value in member total_size + * @param _total_size New value for member total_size + */ + eProsima_user_DllExport void total_size( + int64_t _total_size); + + /*! + * @brief This function returns the value of member total_size + * @return Value of member total_size + */ + eProsima_user_DllExport int64_t total_size() const; + + /*! + * @brief This function returns a reference to member total_size + * @return Reference to member total_size + */ + eProsima_user_DllExport int64_t& total_size(); + + /*! + * @brief This function copies the value in member dims + * @param _dims New value to be copied in member dims + */ + eProsima_user_DllExport void dims( + const ddsmetadata::msg::int64__5& _dims); + + /*! + * @brief This function moves the value in member dims + * @param _dims New value to be moved in member dims + */ + eProsima_user_DllExport void dims( + ddsmetadata::msg::int64__5&& _dims); + + /*! + * @brief This function returns a constant reference to member dims + * @return Constant reference to member dims + */ + eProsima_user_DllExport const ddsmetadata::msg::int64__5& dims() const; + + /*! + * @brief This function returns a reference to member dims + * @return Reference to member dims + */ + eProsima_user_DllExport ddsmetadata::msg::int64__5& dims(); + /*! + * @brief This function copies the value in member strides + * @param _strides New value to be copied in member strides + */ + eProsima_user_DllExport void strides( + const ddsmetadata::msg::int64__5& _strides); + + /*! + * @brief This function moves the value in member strides + * @param _strides New value to be moved in member strides + */ + eProsima_user_DllExport void strides( + ddsmetadata::msg::int64__5&& _strides); + + /*! + * @brief This function returns a constant reference to member strides + * @return Constant reference to member strides + */ + eProsima_user_DllExport const ddsmetadata::msg::int64__5& strides() const; + + /*! + * @brief This function returns a reference to member strides + * @return Reference to member strides + */ + eProsima_user_DllExport ddsmetadata::msg::int64__5& strides(); + /*! + * @brief This function copies the value in member mdata + * @param _mdata New value to be copied in member mdata + */ + eProsima_user_DllExport void mdata( + const std::vector& _mdata); + + /*! + * @brief This function moves the value in member mdata + * @param _mdata New value to be moved in member mdata + */ + eProsima_user_DllExport void mdata( + std::vector&& _mdata); + + /*! + * @brief This function returns a constant reference to member mdata + * @return Constant reference to member mdata + */ + eProsima_user_DllExport const std::vector& mdata() const; + + /*! + * @brief This function returns a reference to member mdata + * @return Reference to member mdata + */ + eProsima_user_DllExport std::vector& mdata(); + + /*! + * @brief This function returns the maximum serialized size of an object + * depending on the buffer alignment. + * @param current_alignment Buffer alignment. + * @return Maximum serialized size. + */ + eProsima_user_DllExport static size_t getMaxCdrSerializedSize( + size_t current_alignment = 0); + + /*! + * @brief This function returns the serialized size of a data depending on the buffer alignment. + * @param data Data which is calculated its serialized size. + * @param current_alignment Buffer alignment. + * @return Serialized size. + */ + eProsima_user_DllExport static size_t getCdrSerializedSize( + const ddsmetadata::msg::DDSMetaData& data, + size_t current_alignment = 0); + + + /*! + * @brief This function serializes an object using CDR serialization. + * @param cdr CDR serialization object. + */ + eProsima_user_DllExport void serialize( + eprosima::fastcdr::Cdr& cdr) const; + + /*! + * @brief This function deserializes an object using CDR serialization. + * @param cdr CDR serialization object. + */ + eProsima_user_DllExport void deserialize( + eprosima::fastcdr::Cdr& cdr); + + + + /*! + * @brief This function returns the maximum serialized size of the Key of an object + * depending on the buffer alignment. + * @param current_alignment Buffer alignment. + * @return Maximum serialized size. + */ + eProsima_user_DllExport static size_t getKeyMaxCdrSerializedSize( + size_t current_alignment = 0); + + /*! + * @brief This function tells you if the Key has been defined for this type + */ + eProsima_user_DllExport static bool isKeyDefined(); + + /*! + * @brief This function serializes the key members of an object using CDR serialization. + * @param cdr CDR serialization object. + */ + eProsima_user_DllExport void serializeKey( + eprosima::fastcdr::Cdr& cdr) const; + + private: + + int64_t m_nd; + int64_t m_type; + int64_t m_elsize; + int64_t m_total_size; + ddsmetadata::msg::int64__5 m_dims; + ddsmetadata::msg::int64__5 m_strides; + std::vector m_mdata; + }; + } // namespace msg +} // namespace ddsmetadata + +#endif // _FAST_DDS_GENERATED_DDSMETADATA_MSG_DDSMETADATA_H_ \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.cc new file mode 100644 index 000000000..260895797 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.cc @@ -0,0 +1,177 @@ +// Copyright 2016 Proyectos y Sistemas de Mantenimiento SL (eProsima). +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*! + * @file DDSMetaDataPubSubTypes.cpp + * This header file contains the implementation of the serialization functions. + * + * This file was generated by the tool fastcdrgen. + */ + + +#include +#include + +#include "DDSMetaDataPubSubTypes.h" + +using SerializedPayload_t = eprosima::fastrtps::rtps::SerializedPayload_t; +using InstanceHandle_t = eprosima::fastrtps::rtps::InstanceHandle_t; + +namespace ddsmetadata { + namespace msg { + + DDSMetaDataPubSubType::DDSMetaDataPubSubType() + { + setName("ddsmetadata::msg::dds_::DDSMetaData_"); + auto type_size = DDSMetaData::getMaxCdrSerializedSize(); + type_size += eprosima::fastcdr::Cdr::alignment(type_size, 4); /* possible submessage alignment */ + m_typeSize = static_cast(type_size) + 4; /*encapsulation*/ + m_isGetKeyDefined = DDSMetaData::isKeyDefined(); + size_t keyLength = DDSMetaData::getKeyMaxCdrSerializedSize() > 16 ? + DDSMetaData::getKeyMaxCdrSerializedSize() : 16; + m_keyBuffer = reinterpret_cast(malloc(keyLength)); + memset(m_keyBuffer, 0, keyLength); + } + + DDSMetaDataPubSubType::~DDSMetaDataPubSubType() + { + if (m_keyBuffer != nullptr) + { + free(m_keyBuffer); + } + } + + bool DDSMetaDataPubSubType::serialize( + void* data, + SerializedPayload_t* payload) + { + DDSMetaData* p_type = static_cast(data); + + // Object that manages the raw buffer. + eprosima::fastcdr::FastBuffer fastbuffer(reinterpret_cast(payload->data), payload->max_size); + // Object that serializes the data. + eprosima::fastcdr::Cdr ser(fastbuffer, eprosima::fastcdr::Cdr::DEFAULT_ENDIAN, eprosima::fastcdr::Cdr::DDS_CDR); + payload->encapsulation = ser.endianness() == eprosima::fastcdr::Cdr::BIG_ENDIANNESS ? CDR_BE : CDR_LE; + // Serialize encapsulation + ser.serialize_encapsulation(); + + try + { + // Serialize the object. + p_type->serialize(ser); + } + catch (eprosima::fastcdr::exception::NotEnoughMemoryException& /*exception*/) + { + return false; + } + + // Get the serialized length + payload->length = static_cast(ser.getSerializedDataLength()); + return true; + } + + bool DDSMetaDataPubSubType::deserialize( + SerializedPayload_t* payload, + void* data) + { + try + { + //Convert DATA to pointer of your type + DDSMetaData* p_type = static_cast(data); + + // Object that manages the raw buffer. + eprosima::fastcdr::FastBuffer fastbuffer(reinterpret_cast(payload->data), payload->length); + + // Object that deserializes the data. + eprosima::fastcdr::Cdr deser(fastbuffer, eprosima::fastcdr::Cdr::DEFAULT_ENDIAN, eprosima::fastcdr::Cdr::DDS_CDR); + + // Deserialize encapsulation. + deser.read_encapsulation(); + payload->encapsulation = deser.endianness() == eprosima::fastcdr::Cdr::BIG_ENDIANNESS ? CDR_BE : CDR_LE; + + // Deserialize the object. + p_type->deserialize(deser); + } + catch (eprosima::fastcdr::exception::NotEnoughMemoryException& /*exception*/) + { + return false; + } + + return true; + } + + std::function DDSMetaDataPubSubType::getSerializedSizeProvider( + void* data) + { + return [data]() -> uint32_t + { + return static_cast(type::getCdrSerializedSize(*static_cast(data))) + + 4u /*encapsulation*/; + }; + } + + void* DDSMetaDataPubSubType::createData() + { + return reinterpret_cast(new DDSMetaData()); + } + + void DDSMetaDataPubSubType::deleteData( + void* data) + { + delete(reinterpret_cast(data)); + } + + bool DDSMetaDataPubSubType::getKey( + void* data, + InstanceHandle_t* handle, + bool force_md5) + { + if (!m_isGetKeyDefined) + { + return false; + } + + DDSMetaData* p_type = static_cast(data); + + // Object that manages the raw buffer. + eprosima::fastcdr::FastBuffer fastbuffer(reinterpret_cast(m_keyBuffer), + DDSMetaData::getKeyMaxCdrSerializedSize()); + + // Object that serializes the data. + eprosima::fastcdr::Cdr ser(fastbuffer, eprosima::fastcdr::Cdr::BIG_ENDIANNESS); + p_type->serializeKey(ser); + if (force_md5 || DDSMetaData::getKeyMaxCdrSerializedSize() > 16) + { + m_md5.init(); + m_md5.update(m_keyBuffer, static_cast(ser.getSerializedDataLength())); + m_md5.finalize(); + for (uint8_t i = 0; i < 16; ++i) + { + handle->value[i] = m_md5.digest[i]; + } + } + else + { + for (uint8_t i = 0; i < 16; ++i) + { + handle->value[i] = m_keyBuffer[i]; + } + } + return true; + } + + + } //End of namespace msg + +} //End of namespace ddsmetadata diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.h new file mode 100644 index 000000000..5376a0461 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/dds/protos/fast_dds/DDSMetaDataPubSubTypes.h @@ -0,0 +1,108 @@ +// Copyright 2016 Proyectos y Sistemas de Mantenimiento SL (eProsima). +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*! + * @file DDSMetaDataPubSubTypes.h + * This header file contains the declaration of the serialization functions. + * + * This file was generated by the tool fastcdrgen. + */ + + +#ifndef _FAST_DDS_GENERATED_DDSMETADATA_MSG_DDSMETADATA_PUBSUBTYPES_H_ +#define _FAST_DDS_GENERATED_DDSMETADATA_MSG_DDSMETADATA_PUBSUBTYPES_H_ + +#include +#include + +#include "DDSMetaData.h" + +#if !defined(GEN_API_VER) || (GEN_API_VER != 1) +#error \ + Generated DDSMetaData is not compatible with current installed Fast DDS. Please, regenerate it with fastddsgen. +#endif // GEN_API_VER + +namespace ddsmetadata +{ + namespace msg + { + typedef std::array int64__5; + /*! + * @brief This class represents the TopicDataType of the type DDSMetaData defined by the user in the IDL file. + * @ingroup DDSMETADATA + */ + class DDSMetaDataPubSubType : public eprosima::fastdds::dds::TopicDataType + { + public: + + typedef DDSMetaData type; + + eProsima_user_DllExport DDSMetaDataPubSubType(); + + eProsima_user_DllExport virtual ~DDSMetaDataPubSubType() override; + + eProsima_user_DllExport virtual bool serialize( + void* data, + eprosima::fastrtps::rtps::SerializedPayload_t* payload) override; + + eProsima_user_DllExport virtual bool deserialize( + eprosima::fastrtps::rtps::SerializedPayload_t* payload, + void* data) override; + + eProsima_user_DllExport virtual std::function getSerializedSizeProvider( + void* data) override; + + eProsima_user_DllExport virtual bool getKey( + void* data, + eprosima::fastrtps::rtps::InstanceHandle_t* ihandle, + bool force_md5 = false) override; + + eProsima_user_DllExport virtual void* createData() override; + + eProsima_user_DllExport virtual void deleteData( + void* data) override; + + #ifdef TOPIC_DATA_TYPE_API_HAS_IS_BOUNDED + eProsima_user_DllExport inline bool is_bounded() const override + { + return false; + } + + #endif // TOPIC_DATA_TYPE_API_HAS_IS_BOUNDED + + #ifdef TOPIC_DATA_TYPE_API_HAS_IS_PLAIN + eProsima_user_DllExport inline bool is_plain() const override + { + return false; + } + + #endif // TOPIC_DATA_TYPE_API_HAS_IS_PLAIN + + #ifdef TOPIC_DATA_TYPE_API_HAS_CONSTRUCT_SAMPLE + eProsima_user_DllExport inline bool construct_sample( + void* memory) const override + { + (void)memory; + return false; + } + + #endif // TOPIC_DATA_TYPE_API_HAS_CONSTRUCT_SAMPLE + + MD5 m_md5; + unsigned char* m_keyBuffer; + }; + } +} + +#endif // _FAST_DDS_GENERATED_DDSMETADATA_MSG_DDSMETADATA_PUBSUBTYPES_H_ \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc.cc new file mode 100644 index 000000000..061f0d34c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc.cc @@ -0,0 +1,39 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include // NOLINT + +namespace message_infrastructure { + +GrpcManager GrpcManager::grpcm_; +GrpcManager::~GrpcManager() { + url_set_.clear(); +} +GrpcManager& GetGrpcManagerSingleton() { + GrpcManager &grpcm = GrpcManager::grpcm_; + return grpcm; +} +bool GrpcManager::CheckURL(const std::string &url) { + std::lock_guard lg(grpc_lock_); + if (url_set_.find(url) != url_set_.end()) { + return false; + } + url_set_.insert(url); + return true; +} +std::string GrpcManager::AllocURL() { + std::string url; + do { + url = DEFAULT_GRPC_URL + + std::to_string(DEFAULT_GRPC_PORT + port_num_.load()); + port_num_.fetch_add(1); + } while (!CheckURL(url)); + return url; +} +void GrpcManager::Release() { + url_set_.clear(); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc.h new file mode 100644 index 000000000..218214641 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc.h @@ -0,0 +1,43 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_GRPC_GRPC_H_ +#define CHANNEL_GRPC_GRPC_H_ + +#include +#include + +#include +#include +#include +#include + +namespace message_infrastructure { + +class GrpcManager { + public: + GrpcManager(const GrpcManager&) = delete; + GrpcManager(GrpcManager&&) = delete; + GrpcManager& operator=(const GrpcManager&) = delete; + GrpcManager& operator=(GrpcManager&&) = delete; + + bool CheckURL(const std::string &url); + std::string AllocURL(); + void Release(); + friend GrpcManager &GetGrpcManagerSingleton(); + + private: + GrpcManager() = default; + ~GrpcManager(); + std::mutex grpc_lock_; + std::atomic port_num_; + static GrpcManager grpcm_; + std::unordered_set url_set_; +}; + +GrpcManager& GetGrpcManagerSingleton(); + +} // namespace message_infrastructure + +#endif // CHANNEL_GRPC_GRPC_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_channel.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_channel.cc new file mode 100644 index 000000000..857b5b811 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_channel.cc @@ -0,0 +1,45 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include + +#include + +namespace message_infrastructure { + +GrpcChannel::GrpcChannel(const std::string &url, + const int &port, + const std::string &src_name, + const std::string &dst_name, + const size_t &size) { + std::string url_ = url + ":" + std::to_string(port); + bool ret = GetGrpcManagerSingleton().CheckURL(url_); + if (!ret) { + LAVA_LOG_ERR("URL is used, Throw an exception\n"); + throw std::invalid_argument(url_ + " is used now!"); + } + send_port_ = std::make_shared(src_name, size, url_); + recv_port_ = std::make_shared(dst_name, size, url_); +} + +GrpcChannel::GrpcChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &size) { + std::string url_ = GetGrpcManagerSingleton().AllocURL(); + send_port_ = std::make_shared(src_name, size, url_); + recv_port_ = std::make_shared(dst_name, size, url_); +} + +AbstractSendPortPtr GrpcChannel::GetSendPort() { + return send_port_; +} + +AbstractRecvPortPtr GrpcChannel::GetRecvPort() { + return recv_port_; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_channel.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_channel.h new file mode 100644 index 000000000..2ae631045 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_channel.h @@ -0,0 +1,41 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_GRPC_GRPC_CHANNEL_H_ +#define CHANNEL_GRPC_GRPC_CHANNEL_H_ + +#include +#include + +#include +#include + +namespace message_infrastructure { + +class GrpcChannel : public AbstractChannel { + public: + GrpcChannel() = delete; + GrpcChannel(const std::string &url, + const int &port, + const std::string &src_name, + const std::string &dst_name, + const size_t &size); + GrpcChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &size); + ~GrpcChannel() override {} + AbstractSendPortPtr GetSendPort(); + AbstractRecvPortPtr GetRecvPort(); + private: + GrpcSendPortPtr send_port_ = nullptr; + GrpcRecvPortPtr recv_port_ = nullptr; +}; + +// Users should be allowed to copy channel objects. +// Use std::shared_ptr. +using GrpcChannelPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CHANNEL_GRPC_GRPC_CHANNEL_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_port.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_port.cc new file mode 100644 index 000000000..a2055989c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_port.cc @@ -0,0 +1,192 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +namespace message_infrastructure { + +namespace { + +MetaDataPtr GrpcMetaData2MetaData(GrpcMetaDataPtr grpcdata) { + MetaDataPtr metadata = std::make_shared(); + metadata->nd = grpcdata->nd(); + metadata->type = grpcdata->type(); + metadata->elsize = grpcdata->elsize(); + metadata->total_size = grpcdata->total_size(); + void* data = std::calloc(metadata->elsize * metadata->total_size, 1); + if (data == nullptr) { + LAVA_LOG_FATAL("Memory alloc failed, errno: %d\n", errno); + } + for (int i = 0; i < metadata->nd; i++) { + metadata->dims[i] = grpcdata->dims(i); + metadata->strides[i] = grpcdata->strides(i); + } + std::memcpy(data, + grpcdata->value().c_str(), + metadata->elsize * metadata->total_size); + metadata->mdata = data; + return metadata; +} + +} // namespace + +template<> +void RecvQueue::FreeData(GrpcMetaDataPtr data) +{} + +GrpcChannelServerImpl::GrpcChannelServerImpl(const std::string& name, + const size_t &size) + :GrpcServerImpl(name, size), done_(false) { + recv_queue_ = std::make_shared>(name_, size_); +} + +Status GrpcChannelServerImpl::RecvArrayData(ServerContext* context, + const GrpcMetaData *request, + DataReply* reply) { + bool rep = true; + while (recv_queue_->AvailableCount() <=0) { + helper::Sleep(); + if (done_) { + // cppcheck-suppress unreadVariable + rep = false; + return Status::OK; + } + } + recv_queue_->Push(std::make_shared(*request)); + reply->set_ack(rep); + return Status::OK; +} + +GrpcMetaDataPtr GrpcChannelServerImpl::Pop(bool block) { + return recv_queue_->Pop(block); +} + +GrpcMetaDataPtr GrpcChannelServerImpl::Front() { + return recv_queue_->Front(); +} + +bool GrpcChannelServerImpl::Probe() { + return recv_queue_->Probe(); +} + +void GrpcChannelServerImpl::Stop() { + done_ = true; + recv_queue_->Stop(); +} + +GrpcChannelBlockServerImpl::GrpcChannelBlockServerImpl(const std::string& name, + const size_t &size) + :GrpcServerImpl(name, size), done_(false), usable_(true) { + block_recv_data_ = nullptr; +} + +Status GrpcChannelBlockServerImpl::RecvArrayData(ServerContext* context, + const GrpcMetaData *request, + DataReply* reply) { + bool rep = true; + while (usable_ != true) { + helper::Sleep(); + if (done_) { + // cppcheck-suppress unreadVariable + rep = false; + return Status::OK; + } + } + block_recv_data_ = std::make_shared(*request); + usable_ = false; + reply->set_ack(rep); + return Status::OK; +} + +GrpcMetaDataPtr GrpcChannelBlockServerImpl::Pop(bool block) { + while (block && usable_ == true) { + helper::Sleep(); + if (done_) + return nullptr; + } + usable_ = true; + return block_recv_data_; +} + +GrpcMetaDataPtr GrpcChannelBlockServerImpl::Front() { + return block_recv_data_; +} + +bool GrpcChannelBlockServerImpl::Probe() { + return block_recv_data_ != nullptr; +} + +void GrpcChannelBlockServerImpl::Stop() { + done_ = true; +} + +GrpcRecvPort::GrpcRecvPort(const std::string& name, + const size_t &size, + const std::string& url): + AbstractRecvPort(name, 1, size), + name_(name), size_(size), done_(false), url_(url) { + if (size_ > 1) { + service_ptr_ = std::make_shared(name_, size_); + } else { + service_ptr_ = std::make_shared(name_, size_); + } +} + +void GrpcRecvPort::Start() { + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + builder_.AddListeningPort(url_, grpc::InsecureServerCredentials()); + builder_.RegisterService(service_ptr_.get()); + server_ = builder_.BuildAndStart(); +} + +MetaDataPtr GrpcRecvPort::Recv() { + GrpcMetaDataPtr recv_data = service_ptr_->Pop(true); + return GrpcMetaData2MetaData(recv_data); +} + +MetaDataPtr GrpcRecvPort::Peek() { + GrpcMetaDataPtr peek_data = service_ptr_->Front(); + return GrpcMetaData2MetaData(peek_data); +} + +void GrpcRecvPort::Join() { + if (!done_) { + done_ = true; + service_ptr_->Stop(); + server_->Shutdown(); + } +} + +bool GrpcRecvPort::Probe() { + return service_ptr_->Probe(); +} + +void GrpcSendPort::Start() { + channel_ = grpc::CreateChannel(url_, grpc::InsecureChannelCredentials()); + stub_ = GrpcChannelServer::NewStub(channel_); +} + +void GrpcSendPort::Send(DataPtr grpcdata) { + GrpcMetaData* data = reinterpret_cast(grpcdata.get()); + DataReply reply; + ClientContext context; + context.set_wait_for_ready(true); + Status status = stub_->RecvArrayData(&context, *data, &reply); + if (!reply.ack()) { + LAVA_LOG_ERR("Send fail!\n"); + } +} + +bool GrpcSendPort::Probe() { + return false; +} + +void GrpcSendPort::Join() { + done_ = true; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_port.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_port.h new file mode 100644 index 000000000..4261e5c64 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/grpc_port.h @@ -0,0 +1,163 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_GRPC_GRPC_PORT_H_ +#define CHANNEL_GRPC_GRPC_PORT_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include //NOLINT +#include +#include + +namespace message_infrastructure { + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Channel; +using grpc::ClientContext; +using grpc::Status; +using grpcchannel::DataReply; +using grpcchannel::GrpcChannelServer; +using grpcchannel::GrpcMetaData; + +using GrpcMetaDataPtr = std::shared_ptr; + +template class RecvQueue; + +inline GrpcMetaDataPtr MetaData2GrpcMetaData(MetaDataPtr metadata) { + GrpcMetaDataPtr grpcdata = std::make_shared(); + grpcdata->set_nd(metadata->nd); + grpcdata->set_type(metadata->type); + grpcdata->set_elsize(metadata->elsize); + grpcdata->set_total_size(metadata->total_size); + // char* data = reinterpret_cast(metadata->mdata); + for (int i = 0; i < metadata->nd; i++) { + grpcdata->add_dims(metadata->dims[i]); + grpcdata->add_strides(metadata->strides[i]); + } + grpcdata->set_value(metadata->mdata, metadata->elsize*metadata->total_size); + return grpcdata; +} +class GrpcServerImpl : public GrpcChannelServer::Service { + public: + GrpcServerImpl(const std::string& name, + const size_t &size): name_(name), size_(size) {} + virtual ~GrpcServerImpl() = default; + virtual Status RecvArrayData(ServerContext* context, + const GrpcMetaData* request, + DataReply* reply) = 0; + virtual GrpcMetaDataPtr Pop(bool block) = 0; + virtual GrpcMetaDataPtr Front() = 0; + virtual bool Probe() = 0; + virtual void Stop() = 0; + protected: + std::string name_; + size_t size_; +}; +class GrpcChannelServerImpl final : public GrpcServerImpl { + public: + GrpcChannelServerImpl(const std::string& name, + const size_t &size); + ~GrpcChannelServerImpl() override {} + Status RecvArrayData(ServerContext* context, + const GrpcMetaData* request, + DataReply* reply) override; + GrpcMetaDataPtr Pop(bool block); + GrpcMetaDataPtr Front(); + bool Probe(); + void Stop(); + + private: + std::shared_ptr> recv_queue_; + std::atomic_bool done_; +}; + +class GrpcChannelBlockServerImpl final : public GrpcServerImpl { + public: + GrpcChannelBlockServerImpl(const std::string& name, + const size_t &size); + ~GrpcChannelBlockServerImpl() override {} + Status RecvArrayData(ServerContext* context, + const GrpcMetaData* request, + DataReply* reply) override; + GrpcMetaDataPtr Pop(bool block); + GrpcMetaDataPtr Front(); + bool Probe(); + void Stop(); + + private: + GrpcMetaDataPtr block_recv_data_; + std::atomic_bool done_; + std::atomic_bool usable_; +}; + +class GrpcRecvPort final : public AbstractRecvPort { + public: + GrpcRecvPort() = delete; + GrpcRecvPort(const std::string& name, + const size_t &size, + const std::string& url); + ~GrpcRecvPort() override {} + void Start(); + MetaDataPtr Recv(); + MetaDataPtr Peek(); + void Join(); + bool Probe(); + + private: + ServerBuilder builder_; + std::atomic_bool done_; + std::unique_ptr server_; + std::shared_ptr service_ptr_; + std::string url_; + std::string name_; + size_t size_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using GrpcRecvPortPtr = std::shared_ptr; + +class GrpcSendPort final : public AbstractSendPort { + public: + GrpcSendPort() = delete; + GrpcSendPort(const std::string &name, + const size_t &size, + const std::string& url): + AbstractSendPort(name, 1, size), + name_(name), size_(size), done_(false), url_(url) {} + + ~GrpcSendPort() override {} + + void Start(); + void Send(DataPtr grpcdata); + void Join(); + bool Probe(); + + private: + std::shared_ptr channel_; + std::atomic_bool done_; + std::unique_ptr stub_; + std::string url_; + std::string name_; + size_t size_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using GrpcSendPortPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CHANNEL_GRPC_GRPC_PORT_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/protos/grpcchannel.proto b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/protos/grpcchannel.proto new file mode 100644 index 000000000..39aa26152 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/grpc/protos/grpcchannel.proto @@ -0,0 +1,45 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.grpcchannel"; +option java_outer_classname = "GrpcChannel"; +//option objc_class_prefix = "HLW"; + +package grpcchannel; + +// The request message containing the user's name. +service GrpcChannelServer{ + rpc RecvArrayData (GrpcMetaData) returns (DataReply){} +} + +message GrpcMetaData{ + int64 nd = 1; + int64 type = 2; + int64 elsize = 3; + int64 total_size = 4; + repeated int64 dims = 5; + repeated int64 strides = 6; + bytes value =16; + +} + +// The response message containing the greetings +message DataReply { + bool ack = 1; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shm.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shm.cc new file mode 100644 index 000000000..0864da7f0 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shm.cc @@ -0,0 +1,145 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include + +namespace message_infrastructure { + +SharedMemory::SharedMemory(const size_t &mem_size, + void* mmap, + const int &key) { + data_ = mmap; + size_ = mem_size; + req_name_ += std::to_string(key); + ack_name_ += std::to_string(key); +} + +SharedMemory::SharedMemory(const size_t &mem_size, void* mmap) { + data_ = mmap; + size_ = mem_size; +} + +SharedMemory::~SharedMemory() { +} + +void SharedMemory::InitSemaphore(sem_t *req, sem_t *ack) { + req_ = req; + ack_ = ack; +} + +void SharedMemory::Start() { +} + +void SharedMemory::Store(HandleFn store_fn) { + sem_wait(ack_); + store_fn(data_); + sem_post(req_); +} + +bool SharedMemory::Load(HandleFn consume_fn) { + bool ret = false; + if (!sem_trywait(req_)) { + consume_fn(data_); + sem_post(ack_); + ret = true; + } + return ret; +} + +void SharedMemory::BlockLoad(HandleFn consume_fn) { + sem_wait(req_); + consume_fn(data_); + sem_post(ack_); +} + +void SharedMemory::Read(HandleFn consume_fn) { + sem_wait(req_); + consume_fn(data_); + sem_post(req_); +} + +bool SharedMemory::TryProbe() { + int val; + sem_getvalue(req_, &val); + return val > 0; +} + +void SharedMemory::Close() { + LAVA_ASSERT_INT(sem_close(req_), 0); + LAVA_ASSERT_INT(sem_close(ack_), 0); +} + +std::string SharedMemory::GetReq() { + return req_name_; +} + +std::string SharedMemory::GetAck() { + return ack_name_; +} + +int SharedMemory::GetDataElem(int offset) { + return static_cast(*(reinterpret_cast(data_) + offset)); +} + +RwSharedMemory::RwSharedMemory(const size_t &mem_size, + void* mmap, + const int &key) + : size_(mem_size), data_(mmap) { + sem_name_ += std::to_string(key); +} + +RwSharedMemory::~RwSharedMemory() { + munmap(data_, size_); +} + +void RwSharedMemory::InitSemaphore() { + sem_ = sem_open(sem_name_.c_str(), O_CREAT, 0644, 0); +} + +void RwSharedMemory::Start() { + sem_post(sem_); +} + +void RwSharedMemory::Handle(HandleFn handle_fn) { + sem_wait(sem_); + handle_fn(data_); + sem_post(sem_); +} + +void RwSharedMemory::Close() { + LAVA_ASSERT_INT(sem_close(sem_), 0); +} + +void SharedMemManager::DeleteAllSharedMemory() { + if (alloc_pid_ != getpid()) + return; + LAVA_DEBUG(LOG_SMMP, "Delete: Number of shm to free: %zd.\n", + shm_fd_strs_.size()); + LAVA_DEBUG(LOG_SMMP, "Delete: Number of sem to free: %zd.\n", + sem_p_strs_.size()); + for (auto const& it : shm_fd_strs_) { + LAVA_ASSERT_INT(shm_unlink(it.second.c_str()), 0); + LAVA_DEBUG(LOG_SMMP, "Shm fd and name close: %s %d\n", + it.second.c_str(), it.first); + LAVA_ASSERT_INT(close(it.first), 0); + } + for (auto const& it : shm_mmap_) { + LAVA_ASSERT_INT(munmap(it.first, it.second), 0); + } + for (auto const& it : sem_p_strs_) { + LAVA_ASSERT_INT(sem_close(it.first), 0); + LAVA_ASSERT_INT(sem_unlink(it.second.c_str()), 0); + } + sem_p_strs_.clear(); + shm_fd_strs_.clear(); + shm_mmap_.clear(); +} + +SharedMemManager SharedMemManager::smm_; + +SharedMemManager& GetSharedMemManagerSingleton() { + SharedMemManager &smm = SharedMemManager::smm_; + return smm; +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shm.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shm.h new file mode 100644 index 000000000..82101db64 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shm.h @@ -0,0 +1,144 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_SHMEM_SHM_H_ +#define CHANNEL_SHMEM_SHM_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +#define SHM_FLAG O_RDWR | O_CREAT +#define SHM_MODE S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH | S_IWOTH + +using HandleFn = std::function; + +class SharedMemory { + public: + SharedMemory() {} + SharedMemory(const size_t &mem_size, void* mmap, const int &key); + SharedMemory(const size_t &mem_size, void* mmap); + ~SharedMemory(); + void Start(); + bool Load(HandleFn consume_fn); + void BlockLoad(HandleFn consume_fn); + void Read(HandleFn consume_fn); + void Store(HandleFn store_fn); + void Close(); + bool TryProbe(); + void InitSemaphore(sem_t* req, sem_t *ack); + int GetDataElem(int offset); + std::string GetReq(); + std::string GetAck(); + + private: + size_t size_; + std::string req_name_ = "req"; + std::string ack_name_ = "ack"; + sem_t *req_; + sem_t *ack_; + void *data_ = nullptr; +}; + +class RwSharedMemory { + public: + RwSharedMemory(const size_t &mem_size, void* mmap, const int &key); + ~RwSharedMemory(); + void InitSemaphore(); + void Start(); + void Handle(HandleFn handle_fn); + void Close(); + + private: + size_t size_; + std::string sem_name_ = "sem"; + sem_t *sem_; + void *data_; +}; + +// SharedMemory object needs to be transfered to ShmemPort. +// RwSharedMemory object needs to be transfered to ShmemPort. +// Also need to be handled in SharedMemManager. +// Use std::shared_ptr. +using SharedMemoryPtr = std::shared_ptr; +using RwSharedMemoryPtr = std::shared_ptr; + +class SharedMemManager { + public: + SharedMemManager(const SharedMemManager&) = delete; + SharedMemManager(SharedMemManager&&) = delete; + SharedMemManager& operator=(const SharedMemManager&) = delete; + SharedMemManager& operator=(SharedMemManager&&) = delete; + template + std::shared_ptr AllocChannelSharedMemory(const size_t &mem_size) { + int random = std::rand(); + std::string str = shm_str_ + std::to_string(random); + int shmfd = shm_open(str.c_str(), SHM_FLAG, SHM_MODE); + LAVA_DEBUG(LOG_SMMP, "Shm fd and name open: %s %d\n", + str.c_str(), shmfd); + if (shmfd == -1) { + LAVA_LOG_FATAL("Create shared memory object failed.\n"); + } + int err = ftruncate(shmfd, mem_size); + if (err == -1) { + LAVA_LOG_FATAL("Resize shared memory segment failed.\n"); + } + shm_fd_strs_.insert({shmfd, str}); + void *mmap_address = mmap(nullptr, mem_size, PROT_READ | PROT_WRITE, + MAP_SHARED, shmfd, 0); + if (mmap_address == reinterpret_cast(-1)) { + LAVA_LOG_ERR("Get shmem address error, errno: %d\n", errno); + LAVA_DUMP(1, "size: %ld, shmfd_: %d\n", mem_size, shmfd); + } + shm_mmap_.insert({mmap_address, mem_size}); + std::shared_ptr shm = + std::make_shared(mem_size, mmap_address, random); + std::string req_name = shm->GetReq(); + std::string ack_name = shm->GetAck(); + sem_t *req = sem_open(req_name.c_str(), O_CREAT, 0644, 0); + sem_t *ack = sem_open(ack_name.c_str(), O_CREAT, 0644, 1); + shm->InitSemaphore(req, ack); + sem_p_strs_.insert({req, req_name}); + sem_p_strs_.insert({ack, ack_name}); + return shm; + } + + void DeleteAllSharedMemory(); + friend SharedMemManager &GetSharedMemManagerSingleton(); + + private: + SharedMemManager() { + std::srand(std::time(nullptr)); + alloc_pid_ = getpid(); + } + ~SharedMemManager() = default; + std::unordered_map shm_fd_strs_; + std::unordered_map sem_p_strs_; + std::unordered_map shm_mmap_; + static SharedMemManager smm_; + std::string shm_str_ = "shm"; + int alloc_pid_; +}; + +SharedMemManager& GetSharedMemManagerSingleton(); + +} // namespace message_infrastructure + +#endif // CHANNEL_SHMEM_SHM_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_channel.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_channel.cc new file mode 100644 index 000000000..d0ac89bed --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_channel.cc @@ -0,0 +1,47 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include + +namespace message_infrastructure { + +ShmemChannel::ShmemChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &size, + const size_t &nbytes) { + size_t shmem_size = nbytes + sizeof(MetaData); + + shm_ = GetSharedMemManagerSingleton().AllocChannelSharedMemory( + shmem_size); + + send_port_ = std::make_shared(src_name, shm_, + size, shmem_size); + if (size > 1) { + recv_port_ = std::make_shared(dst_name, shm_, + size, shmem_size); + } else { + recv_port_ = std::make_shared(dst_name, shm_, + shmem_size); + } +} + +AbstractSendPortPtr ShmemChannel::GetSendPort() { + return send_port_; +} + +AbstractRecvPortPtr ShmemChannel::GetRecvPort() { + return recv_port_; +} + +std::shared_ptr GetShmemChannel(const size_t &size, + const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name) { + return (std::make_shared(src_name, + dst_name, + size, + nbytes)); +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_channel.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_channel.h new file mode 100644 index 000000000..beb7f393a --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_channel.h @@ -0,0 +1,41 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_SHMEM_SHMEM_CHANNEL_H_ +#define CHANNEL_SHMEM_SHMEM_CHANNEL_H_ + +#include +#include +#include +#include + +#include +#include + +namespace message_infrastructure { + +class ShmemChannel : public AbstractChannel { + public: + ShmemChannel() = delete; + ShmemChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &size, + const size_t &nbytes); + ~ShmemChannel() override {} + AbstractSendPortPtr GetSendPort(); + AbstractRecvPortPtr GetRecvPort(); + private: + SharedMemoryPtr shm_ = nullptr; + ShmemSendPortPtr send_port_ = nullptr; + AbstractRecvPortPtr recv_port_ = nullptr; +}; + +std::shared_ptr GetShmemChannel(const size_t &size, + const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name); + +} // namespace message_infrastructure + +#endif // CHANNEL_SHMEM_SHMEM_CHANNEL_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_port.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_port.cc new file mode 100644 index 000000000..b5ac171d1 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_port.cc @@ -0,0 +1,176 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include +#include // NOLINT +#include // NOLINT +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +namespace { + +void MetaDataPtrFromPointer(const MetaDataPtr &ptr, void *p, int nbytes) { + std::memcpy(ptr.get(), p, sizeof(MetaData)); + int len = ptr->elsize * ptr->total_size; + if (len > nbytes) { + LAVA_LOG_ERR("Recv %d data but max support %d length\n", len, nbytes); + len = nbytes; + } + LAVA_DEBUG(LOG_SMMP, "data len: %d, nbytes: %d\n", len, nbytes); + ptr->mdata = std::calloc(len, 1); + if (ptr->mdata == nullptr) { + LAVA_LOG_ERR("alloc failed, errno: %d\n", errno); + } + LAVA_DEBUG(LOG_SMMP, "memory allocates: %p\n", ptr->mdata); + std::memcpy(ptr->mdata, + reinterpret_cast(p) + sizeof(MetaData), len); + LAVA_DEBUG(LOG_SMMP, "Metadata created\n"); +} + +} // namespace + +template<> +void RecvQueue::FreeData(MetaDataPtr data) { + free(data->mdata); +} + +ShmemSendPort::ShmemSendPort(const std::string &name, + SharedMemoryPtr shm, + const size_t &size, + const size_t &nbytes) + : AbstractSendPort(name, size, nbytes), shm_(shm), done_(false) +{} + +void ShmemSendPort::Start() { + shm_->Start(); +} + +void ShmemSendPort::Send(DataPtr metadata) { + auto mdata = reinterpret_cast(metadata.get()); + int len = mdata->elsize * mdata->total_size; + if (len > nbytes_ - sizeof(MetaData)) { + LAVA_LOG_ERR("Send data too large\n"); + } + shm_->Store([len, &metadata](void* data){ + char* cptr = reinterpret_cast(data); + std::memcpy(cptr, metadata.get(), sizeof(MetaData)); + cptr += sizeof(MetaData); + std::memcpy(cptr, + reinterpret_cast(metadata.get())->mdata, + len); + }); +} + +bool ShmemSendPort::Probe() { + return false; +} + +void ShmemSendPort::Join() { + done_ = true; +} + +ShmemRecvPort::ShmemRecvPort(const std::string &name, + SharedMemoryPtr shm, + const size_t &size, + const size_t &nbytes) + : AbstractRecvPort(name, size, nbytes), shm_(shm), done_(false) { + recv_queue_ = std::make_shared>(name_, size_); +} + +ShmemRecvPort::~ShmemRecvPort() { +} + +void ShmemRecvPort::Start() { + recv_queue_thread_ = std::thread( + &message_infrastructure::ShmemRecvPort::QueueRecv, this); +} + +void ShmemRecvPort::QueueRecv() { + while (!done_.load()) { + bool ret = false; + if (this->recv_queue_->AvailableCount() > 0) { + ret = shm_->Load([this](void* data){ + MetaDataPtr metadata_res = std::make_shared(); + MetaDataPtrFromPointer(metadata_res, data, + nbytes_ - sizeof(MetaData)); + this->recv_queue_->Push(metadata_res); + }); + } + if (!ret) { + helper::Sleep(); + } + } +} + +bool ShmemRecvPort::Probe() { + return recv_queue_->Probe(); +} + +MetaDataPtr ShmemRecvPort::Recv() { + return recv_queue_->Pop(true); +} + +void ShmemRecvPort::Join() { + if (!done_) { + done_ = true; + if (recv_queue_thread_.joinable()) + recv_queue_thread_.join(); + recv_queue_->Stop(); + } +} + +MetaDataPtr ShmemRecvPort::Peek() { + MetaDataPtr metadata_res = recv_queue_->Front(); + int mem_size = (nbytes_ - sizeof(MetaData) + 7) & (~0x7); + void * ptr = std::calloc(mem_size, 1); + if (ptr == nullptr) { + LAVA_LOG_ERR("alloc failed, errno: %d\n", errno); + } + LAVA_DEBUG(LOG_SMMP, "memory allocates: %p\n", ptr); + // memcpy to avoid double free + // or maintain a address:refcount map + std::memcpy(ptr, metadata_res->mdata, mem_size); + MetaDataPtr metadata = std::make_shared(); + std::memcpy(metadata.get(), metadata_res.get(), sizeof(MetaData)); + metadata->mdata = ptr; + return metadata; +} + +ShmemBlockRecvPort::ShmemBlockRecvPort(const std::string &name, + SharedMemoryPtr shm, const size_t &nbytes) + : AbstractRecvPort(name, 1, nbytes), shm_(shm) +{} + +MetaDataPtr ShmemBlockRecvPort::Recv() { + MetaDataPtr metadata_res = std::make_shared(); + shm_->BlockLoad([&metadata_res, this](void* data){ + MetaDataPtrFromPointer(metadata_res, data, + nbytes_ - sizeof(MetaData)); + }); + return metadata_res; +} + +MetaDataPtr ShmemBlockRecvPort::Peek() { + MetaDataPtr metadata_res = std::make_shared(); + shm_->Read([&metadata_res, this](void* data){ + MetaDataPtrFromPointer(metadata_res, data, + nbytes_ - sizeof(MetaData)); + }); + return metadata_res; +} + +bool ShmemBlockRecvPort::Probe() { + return shm_->TryProbe(); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_port.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_port.h new file mode 100644 index 000000000..822bd6adc --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/shmem/shmem_port.h @@ -0,0 +1,89 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_SHMEM_SHMEM_PORT_H_ +#define CHANNEL_SHMEM_SHMEM_PORT_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include // NOLINT + +namespace message_infrastructure { + +template class RecvQueue; + +class ShmemSendPort final : public AbstractSendPort { + public: + ShmemSendPort() = delete; + ShmemSendPort(const std::string &name, + SharedMemoryPtr shm, + const size_t &size, + const size_t &nbytes); + ~ShmemSendPort() override {} + void Start(); + void Send(DataPtr metadata); + void Join(); + bool Probe(); + + private: + SharedMemoryPtr shm_ = nullptr; + std::atomic_bool done_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using ShmemSendPortPtr = std::shared_ptr; + +class ShmemBlockRecvPort final : public AbstractRecvPort { + public: + ShmemBlockRecvPort() = delete; + ShmemBlockRecvPort(const std::string &name, + SharedMemoryPtr shm, + const size_t &nbytes); + ~ShmemBlockRecvPort() override {} + void Start() {} + bool Probe(); + MetaDataPtr Recv(); + void Join() {} + MetaDataPtr Peek(); + + private: + SharedMemoryPtr shm_ = nullptr; +}; + +class ShmemRecvPort final : public AbstractRecvPort { + public: + ShmemRecvPort() = delete; + ShmemRecvPort(const std::string &name, + SharedMemoryPtr shm, + const size_t &size, + const size_t &nbytes); + ~ShmemRecvPort(); + void Start(); + bool Probe(); + MetaDataPtr Recv(); + void Join(); + MetaDataPtr Peek(); + void QueueRecv(); + + private: + SharedMemoryPtr shm_ = nullptr; + std::atomic_bool done_; + std::shared_ptr> recv_queue_; + std::thread recv_queue_thread_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using ShmemRecvPortPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CHANNEL_SHMEM_SHMEM_PORT_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket.cc new file mode 100644 index 000000000..db07ae3c7 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket.cc @@ -0,0 +1,66 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +namespace message_infrastructure { + +SktManager::~SktManager() { + for (auto it = sockets_.begin(); it != sockets_.end(); it++) { + close(it->first); + close(it->second); + } + sockets_.clear(); +} + +SocketPair SktManager::AllocChannelSocket(size_t nbytes) { + SocketPair skt_pair; + int socket[2]; + int err = socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, socket); + if (err == -1) { + LAVA_LOG_FATAL("Create socket object failed.\n"); + } + skt_pair.first = socket[0]; + skt_pair.second = socket[1]; + sockets_.push_back(skt_pair); + return skt_pair; +} + +SocketFile SktManager::AllocSocketFile(const std::string &addr_path) { + SocketFile skt_file; + if (std::string() == addr_path) { + LAVA_DEBUG(LOG_SKP, "Creating Socket File\n"); + std::srand(std::time(nullptr)); + do { + skt_file = SKT_TEMP_PATH + std::to_string(std::rand()); + } while (std::filesystem::exists(skt_file)); + } else { + skt_file = addr_path; + } + if (socket_files_.find(skt_file) != socket_files_.end()) { + LAVA_LOG_ERR("Skt File %s is alread used by the process\n", + skt_file.c_str()); + } + socket_files_.insert(skt_file); + return skt_file; +} + +bool SktManager::DeleteSocketFile(const std::string &addr_path) { + if (socket_files_.find(addr_path) == socket_files_.end()) { + LAVA_LOG_WARN(LOG_SKP, "Cannot delete exist file name\n"); + return false; + } + socket_files_.erase(addr_path); + return true; +} + +SktManager SktManager::sktm_; + +SktManager& GetSktManagerSingleton() { + SktManager &sktm = SktManager::sktm_; + return sktm; +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket.h new file mode 100644 index 000000000..0194623ed --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket.h @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_SOCKET_SOCKET_H_ +#define CHANNEL_SOCKET_SOCKET_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SKT_TEMP_PATH "/tmp/skt_tmp_" +#define MAX_SKT_FILENAME_LENGTH 100 + +namespace message_infrastructure { + +using SocketPair = std::pair; +using SocketFile = std::string; + +class SktManager { + public: + SktManager(const SktManager&) = delete; + SktManager(SktManager&&) = delete; + SktManager& operator=(const SktManager&) = delete; + SktManager& operator=(SktManager&&) = delete; + + SocketPair AllocChannelSocket(size_t nbytes); + SocketFile AllocSocketFile(const std::string &addr_path); + bool DeleteSocketFile(const std::string &addr_path); + + friend SktManager &GetSktManagerSingleton(); + + private: + SktManager() = default; + ~SktManager(); + std::vector sockets_; + std::unordered_set socket_files_; + static SktManager sktm_; +}; + +SktManager& GetSktManagerSingleton(); + +} // namespace message_infrastructure + +#endif // CHANNEL_SOCKET_SOCKET_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_channel.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_channel.cc new file mode 100644 index 000000000..3a76f18c5 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_channel.cc @@ -0,0 +1,62 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include + +namespace message_infrastructure { + +SocketChannel::SocketChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &nbytes) { + SocketPair skt = GetSktManagerSingleton().AllocChannelSocket(nbytes); + send_port_ = std::make_shared(src_name, skt, nbytes); + recv_port_ = std::make_shared(dst_name, skt, nbytes); +} + +AbstractSendPortPtr SocketChannel::GetSendPort() { + return send_port_; +} + +AbstractRecvPortPtr SocketChannel::GetRecvPort() { + return recv_port_; +} + +std::shared_ptr GetSocketChannel(const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name) { + return (std::make_shared(src_name, + dst_name, + nbytes)); +} + +TempSocketChannel::TempSocketChannel(const std::string &addr_path) { + addr_path_ = GetSktManagerSingleton().AllocSocketFile(addr_path); +} + +std::string TempSocketChannel::ChannelInfo() { + return addr_path_; +} + +AbstractRecvPortPtr TempSocketChannel::GetRecvPort() { + if (recv_port_ == nullptr) { + recv_port_ = std::make_shared(addr_path_); + } + return recv_port_; +} + +AbstractSendPortPtr TempSocketChannel::GetSendPort() { + if (send_port_ == nullptr) { + send_port_ = std::make_shared(addr_path_); + } + return send_port_; +} + +bool TempSocketChannel::Close() { + return GetSktManagerSingleton().DeleteSocketFile(addr_path_); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_channel.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_channel.h new file mode 100644 index 000000000..2a8d6b42c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_channel.h @@ -0,0 +1,58 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_SOCKET_SOCKET_CHANNEL_H_ +#define CHANNEL_SOCKET_SOCKET_CHANNEL_H_ + +#include +#include +#include +#include + +#include +#include + +namespace message_infrastructure { + +class SocketChannel : public AbstractChannel { + public: + SocketChannel() = delete; + SocketChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &nbytes); + ~SocketChannel() override {} + AbstractSendPortPtr GetSendPort(); + AbstractRecvPortPtr GetRecvPort(); + private: + SocketPair skt_; + SocketSendPortPtr send_port_ = nullptr; + SocketRecvPortPtr recv_port_ = nullptr; +}; + +// Users should be allowed to copy channel objects. +// Use std::shared_ptr. +using SocketChannelPtr = std::shared_ptr; + +SocketChannelPtr GetSocketChannel(const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name); + +class TempSocketChannel : public AbstractChannel { + public: + TempSocketChannel() = delete; + explicit TempSocketChannel(const std::string &addr_path); + ~TempSocketChannel() override {} + AbstractSendPortPtr GetSendPort(); + AbstractRecvPortPtr GetRecvPort(); + std::string ChannelInfo() override; + bool Close(); + private: + SocketFile addr_path_; + TempSocketRecvPortPtr recv_port_ = nullptr; + TempSocketSendPortPtr send_port_ = nullptr; +}; + +} // namespace message_infrastructure + +#endif // CHANNEL_SOCKET_SOCKET_CHANNEL_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_port.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_port.cc new file mode 100644 index 000000000..36708e5f0 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_port.cc @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +#include +#include +#include +#include // NOLINT +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include + +namespace message_infrastructure { + +bool SocketWrite(int fd, void* data, size_t size) { + size_t length = write(fd, reinterpret_cast(data), size); + + if (length != size) { + // cppcheck-suppress conditionAlwaysFalse + if (length == -1) { + LAVA_LOG_ERR("Write socket failed.\n"); + return false; + } + LAVA_LOG_ERR("Write socket error, expected size: %zd, got size: %zd", + size, length); + return false; + } + return true; +} + +bool SocketRead(int fd, void* data, size_t size) { + char *ptr = reinterpret_cast(data); + while (size > 0) { + size_t length = read(fd, ptr, size); // Flawfinder: ignore + size -= length; + ptr += length; + if (length == 0) + break; + } + if (size) { + LAVA_LOG_ERR("Cannot recv all the data\n"); + return false; + } + return true; +} + +void SocketSendPort::Start() {} +void SocketSendPort::Send(DataPtr metadata) { + bool ret = false; + while (!ret) { + ret = SocketWrite(socket_.first, + reinterpret_cast(metadata.get()), + sizeof(MetaData)); + } + ret = false; + while (!ret) { + ret = SocketWrite(socket_.first, + reinterpret_cast(metadata.get())->mdata, + nbytes_); + } +} +void SocketSendPort::Join() { + close(socket_.first); + close(socket_.second); +} +bool SocketSendPort::Probe() { + LAVA_LOG_ERR("Not Support SocketSendPort Port Probe()\n"); + return false; +} + +void SocketRecvPort::Start() {} +bool SocketRecvPort::Probe() { + LAVA_LOG_ERR("Not Support SocketRecvPort Port Probe()\n"); + return false; +} +MetaDataPtr SocketRecvPort::Recv() { + bool ret = false; + MetaDataPtr metadata = std::make_shared(); + ret = SocketRead(socket_.second, metadata.get(), sizeof(MetaData)); + if (!ret) { + metadata.reset(); + return metadata; + } + void *mdata = std::calloc(nbytes_, 1); + if (mdata == nullptr) { + LAVA_LOG_FATAL("Memory alloc failed, errno: %d\n", errno); + } + ret = SocketRead(socket_.second, mdata, nbytes_); + metadata->mdata = mdata; + if (!ret) { + metadata.reset(); + free(mdata); + } + return metadata; +} +void SocketRecvPort::Join() { + close(socket_.first); + close(socket_.second); +} +MetaDataPtr SocketRecvPort::Peek() { + return Recv(); +} + +TempSocketSendPort::TempSocketSendPort(const SocketFile &addr_path) { + name_ = "SendPort" + addr_path; + addr_path_ = addr_path; + cfd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (cfd_ == -1) { + LAVA_LOG_ERR("Cannot Create Socket Domain File Descripter\n"); + } + + size_t skt_addr_len = sizeof(sa_family_t) + addr_path_.size(); + sockaddr *skt_addr = reinterpret_cast(malloc(skt_addr_len)); + skt_addr->sa_family = AF_UNIX; + memcpy(skt_addr->sa_data, addr_path.c_str(), addr_path_.size()); + + if (connect(cfd_, skt_addr, skt_addr_len) == -1) { + LAVA_LOG_ERR("Cannot bind socket domain\n"); + } +} +void TempSocketSendPort::Start() {} +bool TempSocketSendPort::Probe() { + LAVA_LOG_ERR("Not Support TempSocket Port Probe()\n"); + return false; +} +void TempSocketSendPort::Send(DataPtr data) { + auto metadata = reinterpret_cast(data.get()); + bool flag; + flag = SocketWrite(cfd_, metadata, sizeof(MetaData)); + if (!flag) { + LAVA_LOG_ERR("TempSkt Send data header Error\n"); + } + flag = SocketWrite(cfd_, + metadata->mdata, + metadata->total_size * metadata->elsize); + if (!flag) { + LAVA_LOG_ERR("TempSkt Send data error\n"); + } + LAVA_DEBUG(LOG_SKP, + "Send %ld data\n", + metadata->total_size * metadata->elsize); +} +void TempSocketSendPort::Join() { + close(cfd_); + GetSktManagerSingleton().DeleteSocketFile(addr_path_); +} + +TempSocketRecvPort::TempSocketRecvPort(const SocketFile &addr_path) { + name_ = "RecvPort_" + addr_path; + addr_path_ = addr_path; + sfd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (sfd_ == -1) { + LAVA_LOG_ERR("Cannot Create Socket Domain File Descripter\n"); + } + + size_t skt_addr_len = sizeof(sa_family_t) + addr_path_.size(); + sockaddr *skt_addr = reinterpret_cast(malloc(skt_addr_len)); + skt_addr->sa_family = AF_UNIX; + memcpy(&skt_addr->sa_data[0], addr_path.c_str(), addr_path_.size()); + LAVA_DEBUG(LOG_SKP, + "the path: %s, %zd\n", + &skt_addr->sa_data[0], + addr_path_.size()); + if (bind(sfd_, skt_addr, skt_addr_len) == -1) { + LAVA_LOG_ERR("Cannot bind socket domain\n"); + } +} +void TempSocketRecvPort::Start() { + if (listen(sfd_, 1) == -1) { + LAVA_LOG_ERR("Cannot Listen service socket file, %d\n", errno); + } +} +bool TempSocketRecvPort::Probe() { + LAVA_LOG_ERR("Not Support TempSocket Port Probe()\n"); + return false; +} +MetaDataPtr TempSocketRecvPort::Recv() { + bool flag; + int cfd = accept(sfd_, nullptr, nullptr); + if (cfd == -1) { + LAVA_LOG_ERR("Cannot accept the connection\n"); + } + MetaDataPtr data = std::make_shared(); + flag = SocketRead(cfd, data.get(), sizeof(MetaData)); + if (!flag) { + LAVA_LOG_ERR("TempSkt Recv data header error\n"); + } + void *ptr = calloc(data->elsize * data->total_size, 1); + flag = SocketRead(cfd, ptr, data->elsize * data->total_size); + if (!flag) { + LAVA_LOG_ERR("TempSkt Recv data error\n"); + } + LAVA_DEBUG(LOG_SKP, "Recv %ld data\n", data->elsize * data->total_size); + data->mdata = ptr; + close(cfd); + return data; +} +MetaDataPtr TempSocketRecvPort::Peek() { + LAVA_LOG_ERR("Not Support TempSocket Port Peek()\n"); + return nullptr; +} +void TempSocketRecvPort::Join() { + close(sfd_); + unlink(addr_path_.c_str()); + GetSktManagerSingleton().DeleteSocketFile(addr_path_); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_port.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_port.h new file mode 100644 index 000000000..609364064 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel/socket/socket_port.h @@ -0,0 +1,100 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_SOCKET_SOCKET_PORT_H_ +#define CHANNEL_SOCKET_SOCKET_PORT_H_ + +#include +#include +#include + +#include +#include +#include + +namespace message_infrastructure { + +class SocketSendPort final : public AbstractSendPort { + public: + SocketSendPort() = delete; + SocketSendPort(const std::string &name, + const SocketPair &socket, + const size_t &nbytes) : + AbstractSendPort(name, 1, nbytes), + name_(name), nbytes_(nbytes), socket_(socket) {} + ~SocketSendPort() override {} + void Start(); + void Send(DataPtr metadata); + void Join(); + bool Probe(); + + private: + std::string name_; + size_t nbytes_; + SocketPair socket_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using SocketSendPortPtr = std::shared_ptr; + +class SocketRecvPort final : public AbstractRecvPort { + public: + SocketRecvPort() = delete; + SocketRecvPort(const std::string &name, + const SocketPair &socket, + const size_t &nbytes) : + AbstractRecvPort(name, 1, nbytes), + name_(name), nbytes_(nbytes), socket_(socket) {} + ~SocketRecvPort() override {} + void Start(); + bool Probe(); + MetaDataPtr Recv(); + void Join(); + MetaDataPtr Peek(); + + private: + std::string name_; + size_t nbytes_; + SocketPair socket_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using SocketRecvPortPtr = std::shared_ptr; + +class TempSocketSendPort final : public AbstractSendPort { + public: + TempSocketSendPort() = delete; + TempSocketSendPort(const SocketFile &addr_path); + ~TempSocketSendPort() override {}; + void Start(); + void Send(DataPtr metadata); + void Join(); + bool Probe(); + private: + int cfd_; + SocketFile addr_path_; +}; +using TempSocketSendPortPtr = std::shared_ptr; + +class TempSocketRecvPort final : public AbstractRecvPort { + public: + TempSocketRecvPort() = delete; + TempSocketRecvPort(const SocketFile &addr_path); + ~TempSocketRecvPort() override {} + void Start(); + bool Probe(); + void Join(); + MetaDataPtr Recv(); + MetaDataPtr Peek(); + private: + int sfd_; + SocketFile addr_path_; +}; +using TempSocketRecvPortPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CHANNEL_SOCKET_SOCKET_PORT_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel_proxy.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel_proxy.cc new file mode 100644 index 000000000..fda6cede6 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel_proxy.cc @@ -0,0 +1,124 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +namespace message_infrastructure { + +ChannelProxy::ChannelProxy(const ChannelType &channel_type, + const size_t &size, + const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name, + py::tuple shape, + py::object type) { + ChannelFactory &channel_factory = GetChannelFactory(); + channel_ = channel_factory.GetChannel(channel_type, + size, + nbytes, + src_name, + dst_name); + send_port_ = std::make_shared(channel_type, + channel_->GetSendPort(), + shape, + type); + recv_port_ = std::make_shared(channel_type, + channel_->GetRecvPort(), + shape, + type); +} +SendPortProxyPtr ChannelProxy::GetSendPort() { + return send_port_; +} +RecvPortProxyPtr ChannelProxy::GetRecvPort() { + return recv_port_; +} + +TempChannelProxy::TempChannelProxy() { + ChannelFactory &channel_factory = GetChannelFactory(); + channel_ = channel_factory.GetTempChannel(std::string()); +} + +TempChannelProxy::TempChannelProxy(const std::string &addr_path) { + ChannelFactory &channel_factory = GetChannelFactory(); + channel_ = channel_factory.GetTempChannel(addr_path); +} + +SendPortProxyPtr TempChannelProxy::GetSendPort() { + return std::make_shared(ChannelType::TEMPCHANNEL, + channel_->GetSendPort()); +} +RecvPortProxyPtr TempChannelProxy::GetRecvPort() { + return std::make_shared(ChannelType::TEMPCHANNEL, + channel_->GetRecvPort()); +} + +std::string TempChannelProxy::GetAddrPath() { + return channel_->ChannelInfo(); +} + +#if defined(GRPC_CHANNEL) +GetRPCChannelProxy::GetRPCChannelProxy(const std::string &url, + const int &port, + const std::string &src_name, + const std::string &dst_name, + const size_t &size) { + ChannelFactory &channel_factory = GetChannelFactory(); + channel_ = channel_factory.GetRPCChannel(url, port, src_name, dst_name, size); + send_port_ = std::make_shared(channel_type, + channel_->GetSendPort()); + recv_port_ = std::make_shared(channel_type, + channel_->GetRecvPort()); +} +GetRPCChannelProxy::GetRPCChannelProxy(const std::string &src_name, + const std::string &dst_name, + const size_t &size) { + ChannelFactory &channel_factory = GetChannelFactory(); + channel_ = channel_factory.GetDefRPCChannel(src_name, dst_name, size); + send_port_ = std::make_shared(channel_type, + channel_->GetSendPort()); + recv_port_ = std::make_shared(channel_type, + channel_->GetRecvPort()); +} +SendPortProxyPtr GetRPCChannelProxy::GetSendPort() { + return send_port_; +} +RecvPortProxyPtr GetRPCChannelProxy::GetRecvPort() { + return recv_port_; +} +#endif + +#if defined(DDS_CHANNEL) +GetDDSChannelProxy::GetDDSChannelProxy( + const std::string &src_name, + const std::string &dst_name, + const std::string &topic_name, + const size_t &size, + const size_t &nbytes, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend) { + ChannelFactory &channel_factory = GetChannelFactory(); + channel_ = channel_factory.GetDDSChannel(src_name, + dst_name, + topic_name, + size, + nbytes, + dds_transfer_type, + dds_backend); + send_port_ = std::make_shared(ChannelType::DDSCHANNEL, + channel_->GetSendPort()); + recv_port_ = std::make_shared(ChannelType::DDSCHANNEL, + channel_->GetRecvPort()); +} + +SendPortProxyPtr GetDDSChannelProxy::GetSendPort() { + return send_port_; +} +RecvPortProxyPtr GetDDSChannelProxy::GetRecvPort() { + return recv_port_; +} +#endif +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel_proxy.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel_proxy.h new file mode 100644 index 000000000..ae4f8d68d --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/channel_proxy.h @@ -0,0 +1,90 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CHANNEL_PROXY_H_ +#define CHANNEL_PROXY_H_ + +#include +#include +#include +#if defined(DDS_CHANNEL) +#include +#endif + +#include +#include + +namespace message_infrastructure { + +class ChannelProxy { + public: + ChannelProxy(const ChannelType &channel_type, + const size_t &size, + const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name, + py::tuple shape = py::make_tuple(), + py::object type = py::none()); + SendPortProxyPtr GetSendPort(); + RecvPortProxyPtr GetRecvPort(); + private: + AbstractChannelPtr channel_ = nullptr; + SendPortProxyPtr send_port_ = nullptr; + RecvPortProxyPtr recv_port_ = nullptr; +}; + +class TempChannelProxy { + public: + TempChannelProxy(); + explicit TempChannelProxy(const std::string &addr_path); + SendPortProxyPtr GetSendPort(); + RecvPortProxyPtr GetRecvPort(); + std::string GetAddrPath(); + private: + AbstractChannelPtr channel_ = nullptr; +}; + +#if defined(GRPC_CHANNEL) +class GetRPCChannelProxy { + public: + GetRPCChannelProxy(const std::string &url, + const int &port, + const std::string &src_name, + const std::string &dst_name, + const size_t &size); + GetRPCChannelProxy(const std::string &src_name, + const std::string &dst_name, + const size_t &size); + SendPortProxyPtr GetSendPort(); + RecvPortProxyPtr GetRecvPort(); + private: + ChannelType channel_type = ChannelType::RPCCHANNEL; + AbstractChannelPtr channel_ = nullptr; + SendPortProxyPtr send_port_ = nullptr; + RecvPortProxyPtr recv_port_ = nullptr; +}; +#endif + +#if defined(DDS_CHANNEL) +class GetDDSChannelProxy { + public: + GetDDSChannelProxy(const std::string &src_name, + const std::string &dst_name, + const std::string &topic_name, + const size_t &size, + const size_t &nbytes, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend); + SendPortProxyPtr GetSendPort(); + RecvPortProxyPtr GetRecvPort(); + private: + AbstractChannelPtr channel_ = nullptr; + SendPortProxyPtr send_port_ = nullptr; + RecvPortProxyPtr recv_port_ = nullptr; +}; +#endif + +} // namespace message_infrastructure + +#endif // CHANNEL_PROXY_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_actor.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_actor.cc new file mode 100644 index 000000000..14ee9f83b --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_actor.cc @@ -0,0 +1,96 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +namespace message_infrastructure { + +AbstractActor::AbstractActor(AbstractActor::TargetFn target_fn) + : target_fn_(target_fn) { + ctl_shm_ = GetSharedMemManagerSingleton() + .AllocChannelSharedMemory(sizeof(int)); + ctl_shm_->Start(); +} + +void AbstractActor::Control(const ActorCmd cmd) { + ctl_shm_->Store([cmd](void* data){ + auto ctrl_cmd = reinterpret_cast(data); + *ctrl_cmd = cmd; + LAVA_DEBUG(LOG_MP, "Cmd Get: %d\n", static_cast(cmd)); + }); +} + +void AbstractActor::HandleCmd() { + while (actor_status_.load() < static_cast(ActorStatus::StatusStopped)) { + auto ret = ctl_shm_->Load([this](void *data){ + auto ctrl_status = reinterpret_cast(data); + if (*ctrl_status == static_cast(ActorCmd::CmdStop)) { + actor_status_.store(static_cast(ActorStatus::StatusStopped)); + } else if (*ctrl_status == static_cast(ActorCmd::CmdPause)) { + actor_status_.store(static_cast(ActorStatus::StatusPaused)); + } else if (*ctrl_status == static_cast(ActorCmd::CmdRun)) { + actor_status_.store(static_cast(ActorStatus::StatusRunning)); + } + }); + if (!ret) { + helper::Sleep(); + } + } +} + +bool AbstractActor::SetStatus(ActorStatus status) { + LAVA_DEBUG(LOG_MP, "Set Status: %d\n", static_cast(status)); + auto const curr_status = actor_status_.load(); + if (curr_status >= static_cast(ActorStatus::StatusStopped) + && static_cast(status) < curr_status) { + return false; + } + actor_status_.store(static_cast(status)); + return true; +} + +ActorStatus AbstractActor::GetStatus() { + return static_cast(actor_status_.load()); +} + +void AbstractActor::SetStopFn(StopFn stop_fn) { + stop_fn_ = stop_fn; +} + +void AbstractActor::Run() { + InitStatus(); + while (true) { + if (actor_status_.load() >= static_cast(ActorStatus::StatusStopped)) { + break; + } + if (actor_status_.load() == static_cast(ActorStatus::StatusRunning)) { + target_fn_(); + LAVA_LOG(LOG_MP, "Actor:ActorStatus:%d\n", static_cast(GetStatus())); + if (!loop_run_) { + Control(ActorCmd::CmdStop); + break; + } + } else { + // pause status + helper::Sleep(); + } + } + if (handle_cmd_thread_.joinable()) { + handle_cmd_thread_.join(); + } + if (stop_fn_ != nullptr && + actor_status_.load() != static_cast(ActorStatus::StatusTerminated)) { + stop_fn_(); + } + LAVA_LOG(LOG_ACTOR, "child exist, pid:%d\n", pid_); +} + +void AbstractActor::InitStatus() { + actor_status_.store(static_cast(ActorStatus::StatusRunning)); + handle_cmd_thread_ = std::thread(&AbstractActor::HandleCmd, this); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_actor.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_actor.h new file mode 100644 index 000000000..424e0f911 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_actor.h @@ -0,0 +1,81 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_ABSTRACT_ACTOR_H_ +#define CORE_ABSTRACT_ACTOR_H_ + +#include +#include +#include +#include +#include +#include +#include // NOLINT + +namespace message_infrastructure { + +enum class ActorType { + RuntimeActor = 0, + RuntimeServiceActor = 1, + ProcessModelActor = 2 +}; + +enum class ActorStatus { + StatusError = -1, + StatusRunning = 0, + StatusPaused = 1, + StatusStopped = 2, + StatusTerminated = 3, +}; + +enum class ActorCmd { + CmdRun = 0, + CmdStop = -1, + CmdPause = -2 +}; + +struct ActorCtrlStatus { + ActorCmd cmd; + ActorStatus status; +}; + +class AbstractActor { + public: + using ActorPtr = AbstractActor *; + using TargetFn = std::function; + using StopFn = std::function; + + explicit AbstractActor(TargetFn target_fn); + virtual ~AbstractActor() = default; + virtual int ForceStop() = 0; + virtual int Wait() = 0; + virtual ProcessType Create() = 0; + void Control(const ActorCmd cmd); + ActorStatus GetStatus(); + bool SetStatus(ActorStatus status); + void SetStopFn(StopFn stop_fn); + int GetPid() { + return pid_; + } + + protected: + void Run(); + int pid_ = -1; + + private: + SharedMemoryPtr ctl_shm_; + std::atomic actor_status_; + std::thread handle_cmd_thread_; + TargetFn target_fn_ = nullptr; + StopFn stop_fn_ = nullptr; + bool loop_run_ = false; + void InitStatus(); + void HandleCmd(); +}; + +using SharedActorPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CORE_ABSTRACT_ACTOR_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_channel.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_channel.h new file mode 100644 index 000000000..fcc1b8d3f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_channel.h @@ -0,0 +1,33 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_ABSTRACT_CHANNEL_H_ +#define CORE_ABSTRACT_CHANNEL_H_ + +#include +#include +#include + + +namespace message_infrastructure { + +class AbstractChannel { + public: + virtual ~AbstractChannel() = default; + ChannelType channel_type_; + + virtual AbstractSendPortPtr GetSendPort() = 0; + virtual AbstractRecvPortPtr GetRecvPort() = 0; + virtual std::string ChannelInfo() { + return std::string(); + } +}; + +// Users should be allowed to copy channel objects. +// Use std::shared_ptr. +using AbstractChannelPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // CORE_ABSTRACT_CHANNEL_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port.cc new file mode 100644 index 000000000..dcbbd129b --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port.cc @@ -0,0 +1,20 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include + +namespace message_infrastructure { + +AbstractPort::AbstractPort( + const std::string &name, const size_t &size, const size_t &nbytes) + : name_(name), size_(size), nbytes_(nbytes) +{} + +std::string AbstractPort::Name() { + return name_; +} +size_t AbstractPort::Size() { + return size_; +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port.h new file mode 100644 index 000000000..b7f159dee --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port.h @@ -0,0 +1,62 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_ABSTRACT_PORT_H_ +#define CORE_ABSTRACT_PORT_H_ + +#include +#include +#include +#include + +namespace message_infrastructure { + +class AbstractPort { + public: + AbstractPort(const std::string &name, const size_t &size, + const size_t &nbytes); + AbstractPort() = default; + virtual ~AbstractPort() = default; + + std::string Name(); + size_t Size(); + virtual void Start() = 0; + virtual void Join() = 0; + virtual bool Probe() = 0; + + protected: + std::string name_; + size_t size_; + size_t nbytes_; +}; + +class AbstractSendPort : public AbstractPort { + public: + using AbstractPort::AbstractPort; + virtual ~AbstractSendPort() = default; + virtual void Start() = 0; + virtual void Send(DataPtr data) = 0; + virtual void Join() = 0; +}; + +class AbstractRecvPort : public AbstractPort { + public: + using AbstractPort::AbstractPort; + virtual ~AbstractRecvPort() = default; + virtual void Start() = 0; + virtual MetaDataPtr Recv() = 0; + virtual MetaDataPtr Peek() = 0; + virtual void Join() = 0; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using AbstractSendPortPtr = std::shared_ptr; +using AbstractRecvPortPtr = std::shared_ptr; +using SendPortList = std::list; +using RecvPortList = std::list; + +} // namespace message_infrastructure + +#endif // CORE_ABSTRACT_PORT_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port_implementation.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port_implementation.cc new file mode 100644 index 000000000..4272198d1 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port_implementation.cc @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include + +namespace message_infrastructure { + +AbstractPortImplementation::AbstractPortImplementation( + const SendPortList &send_ports, + const RecvPortList &recv_ports) + : send_ports_(send_ports), recv_ports_(recv_ports) + {} +AbstractPortImplementation::AbstractPortImplementation( + const SendPortList &send_ports) + : send_ports_(send_ports) + {} +AbstractPortImplementation::AbstractPortImplementation( + const RecvPortList &recv_ports) + : recv_ports_(recv_ports) + {} + +int AbstractPortImplementation::Start() { + for (auto port : send_ports_) { + port->Start(); + } + for (auto port : recv_ports_) { + port->Start(); + } + return 0; +} + +int AbstractPortImplementation::Join() { + for (auto port : send_ports_) { + port->Join(); + } + for (auto port : recv_ports_) { + port->Join(); + } + return 0; +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port_implementation.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port_implementation.h new file mode 100644 index 000000000..29021769d --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/abstract_port_implementation.h @@ -0,0 +1,29 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_ABSTRACT_PORT_IMPLEMENTATION_H_ +#define CORE_ABSTRACT_PORT_IMPLEMENTATION_H_ + +#include + +namespace message_infrastructure { + +class AbstractPortImplementation { + public: + explicit AbstractPortImplementation(const SendPortList &send_ports, + const RecvPortList &recv_ports); + explicit AbstractPortImplementation(const RecvPortList &recv_ports); + explicit AbstractPortImplementation(const SendPortList &send_ports); + virtual ~AbstractPortImplementation() = default; + int Start(); + int Join(); + + protected: + SendPortList send_ports_; + RecvPortList recv_ports_; +}; + +} // namespace message_infrastructure + +#endif // CORE_ABSTRACT_PORT_IMPLEMENTATION_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/channel_factory.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/channel_factory.cc new file mode 100644 index 000000000..51c322562 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/channel_factory.cc @@ -0,0 +1,84 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include +#if defined(GRPC_CHANNEL) +#include +#endif + +#if defined(DDS_CHANNEL) +#include +#endif +namespace message_infrastructure { + +AbstractChannelPtr ChannelFactory::GetChannel(const ChannelType &channel_type, + const size_t &size, + const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name) { + switch (channel_type) { +#if defined(DDS_CHANNEL) + case ChannelType::DDSCHANNEL: + return GetDefaultDDSChannel(nbytes, size, src_name, dst_name); +#endif + case ChannelType::SOCKETCHANNEL: + return GetSocketChannel(nbytes, src_name, dst_name); + default: + return GetShmemChannel(size, nbytes, src_name, dst_name); + } + return nullptr; +} + +AbstractChannelPtr ChannelFactory::GetTempChannel( + const std::string &addr_path) { + return std::make_shared(addr_path); +} + +#if defined(DDS_CHANNEL) +AbstractChannelPtr ChannelFactory::GetDDSChannel( + const std::string &src_name, + const std::string &dst_name, + const std::string &topic_name, + const size_t &size, + const size_t &nbytes, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend) { + return std::make_shared(src_name, + dst_name, + topic_name, + size, + nbytes, + dds_transfer_type, + dds_backend); +} +#endif + +#if defined(GRPC_CHANNEL) +AbstractChannelPtr ChannelFactory::GetRPCChannel(const std::string &url, + const int &port, + const std::string &src_name, + const std::string &dst_name, + const size_t &size) { + return std::make_shared(url, port, src_name, dst_name, size); +} + +AbstractChannelPtr ChannelFactory::GetDefRPCChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &size) { + return std::make_shared(src_name, dst_name, size); +} +#endif + +ChannelFactory ChannelFactory::channel_factory_; + +ChannelFactory& GetChannelFactory() { + ChannelFactory &channel_factory = ChannelFactory::channel_factory_; + return channel_factory; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/channel_factory.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/channel_factory.h new file mode 100644 index 000000000..49af6844f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/channel_factory.h @@ -0,0 +1,64 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_CHANNEL_FACTORY_H_ +#define CORE_CHANNEL_FACTORY_H_ + +#include +#include + +#include +#include + +namespace message_infrastructure { + +class ChannelFactory { + public: + ChannelFactory(const ChannelFactory&) = delete; + ChannelFactory(ChannelFactory&&) = delete; + ChannelFactory& operator=(const ChannelFactory&) = delete; + ChannelFactory& operator=(ChannelFactory&&) = delete; + + AbstractChannelPtr GetChannel(const ChannelType &channel_type, + const size_t &size, + const size_t &nbytes, + const std::string &src_name, + const std::string &dst_name); + + AbstractChannelPtr GetTempChannel(const std::string &addr_path); +#if defined(DDS_CHANNEL) + AbstractChannelPtr GetDDSChannel(const std::string &src_name, + const std::string &dst_name, + const std::string &topic_name, + const size_t &size, + const size_t &nbytes, + const DDSTransportType &dds_transfer_type, + const DDSBackendType &dds_backend); +#endif + +#if defined(GRPC_CHANNEL) + AbstractChannelPtr GetRPCChannel(const std::string &url, + const int &port, + const std::string &src_name, + const std::string &dst_name, + const size_t &size); + + AbstractChannelPtr GetDefRPCChannel(const std::string &src_name, + const std::string &dst_name, + const size_t &size); +#endif + + friend ChannelFactory& GetChannelFactory(); + + private: + ~ChannelFactory() = default; + ChannelFactory() = default; + static ChannelFactory channel_factory_; +}; + +ChannelFactory& GetChannelFactory(); + +} // namespace message_infrastructure + +#endif // CORE_CHANNEL_FACTORY_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/common.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/common.h new file mode 100644 index 000000000..f58d161ad --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/common.h @@ -0,0 +1,121 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_COMMON_H_ +#define CORE_COMMON_H_ + +#include + +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +template +class RecvQueue{ + public: + RecvQueue(const std::string& name, const size_t &size) + : name_(name), size_(size), read_index_(0), write_index_(0), done_(false) { + array_.resize(size_); + } + ~RecvQueue() { + Free(); + } + void Push(T val) { + auto const curr_write_index = write_index_.load(std::memory_order_relaxed); + auto next_write_index = curr_write_index + 1; + if (next_write_index == size_) { + next_write_index = 0; + } + if (next_write_index != read_index_.load(std::memory_order_acquire)) { + array_[curr_write_index] = val; + write_index_.store(next_write_index, std::memory_order_release); + } + } + T Pop(bool block) { + while (block && Empty()) { + helper::Sleep(); + if (done_) + return nullptr; + } + auto const curr_read_index = read_index_.load(std::memory_order_relaxed); + assert(curr_read_index != write_index_.load(std::memory_order_acquire)); + T data_ = array_[curr_read_index]; + auto next_read_index = curr_read_index + 1; + if (next_read_index == size_) { + next_read_index = 0; + } + read_index_.store(next_read_index, std::memory_order_release); + return data_; + } + int AvailableCount() { + auto const curr_read_index = read_index_.load(std::memory_order_acquire); + auto const curr_write_index = write_index_.load(std::memory_order_acquire); + if (curr_read_index == curr_write_index) { + return size_; + } + if (curr_write_index > curr_read_index) { + return size_ - curr_write_index + curr_read_index - 1; + } + return curr_read_index - curr_write_index - 1; + } + T Front() { + while (Empty()) { + helper::Sleep(); + if (done_) + return nullptr; + } + auto curr_read_index = read_index_.load(std::memory_order_acquire); + T ptr = array_[curr_read_index]; + return ptr; + } + bool Empty() { + auto const curr_read_index = read_index_.load(std::memory_order_acquire); + auto const curr_write_index = write_index_.load(std::memory_order_acquire); + return curr_read_index == curr_write_index; + } + void Free() { + if (!Empty()) { + auto const curr_read_index = read_index_.load(std::memory_order_acquire); + auto const curr_write_index = write_index_.load(std::memory_order_acquire); // NOLINT + int max, min; + if (curr_read_index < curr_write_index) { + max = curr_write_index; + min = curr_read_index; + } else { + min = curr_write_index + 1; + max = curr_read_index + 1; + } + for (int i = min; i < max; i++) { + FreeData(array_[i]); + array_[i] = nullptr; + } + read_index_.store(0, std::memory_order_release); + write_index_.store(0, std::memory_order_release); + } + } + bool Probe() { + return !Empty(); + } + void Stop() { + done_ = true; + } + + private: + std::vector array_; + std::atomic read_index_; + std::atomic write_index_; + std::string name_; + size_t size_; + std::atomic_bool done_; + + void FreeData(T data); +}; + +} // namespace message_infrastructure + +#endif // CORE_COMMON_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/message_infrastructure_logging.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/message_infrastructure_logging.cc new file mode 100644 index 000000000..5a544f60b --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/message_infrastructure_logging.cc @@ -0,0 +1,109 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include + +namespace message_infrastructure { + +namespace { + +signed int GetPid() { + return getpid(); +} + +std::string GetTime() { + char buf[MAX_SIZE_LOG_TIME] = {}; + struct timespec ts; + timespec_get(&ts, TIME_UTC); + int end = strftime(buf, sizeof(buf), "%Y-%m-%d.%X", gmtime(&ts.tv_sec)); + snprintf(buf + end, MAX_SIZE_LOG_TIME-end, " %09ld", ts.tv_nsec); + return std::string(buf); +} + +} // namespace + +LogMsg::LogMsg(const std::string &msg_data, + const char *log_file, + const int &log_line, + const char *log_level) + : msg_data_(msg_data), + msg_line_(log_line), + msg_file_(log_file), + msg_level_(log_level) { + msg_time_ = GetTime(); +} + +std::string LogMsg::GetEntireLogMsg(const int &pid) { + std::stringstream buf; + buf << msg_time_ << " "; + buf << msg_level_ << " "; + buf << pid << " "; + buf << msg_file_ << ":"; + buf << msg_line_ << " "; + buf << msg_data_; + return buf.str(); +} + +MessageInfrastructureLog::MessageInfrastructureLog() { + char *log_path = getenv("MSG_LOG_PATH"); + if (log_path == nullptr) { + log_path = getcwd(nullptr, 0); + log_path_ = log_path; + free(log_path); + return; + } + log_path_ = log_path; +} + +// multithread safe +void MessageInfrastructureLog::LogWrite(const LogMsg& msg) { + std::lock_guard lg(log_lock_); + log_queue_.push(msg); + if (log_queue_.size() == MAX_SIZE_LOG) { + WriteDown(); + } +} + +// multithread unsafe +void MessageInfrastructureLog::Clear() { + std::queue().swap(log_queue_); +} + +// multithread unsafe +void MessageInfrastructureLog::WriteDown() { + if (log_queue_.empty()) return; + signed int pid = GetPid(); + std::stringstream log_file_name; + log_file_name << log_path_ << "/" << DEBUG_LOG_MODULE << "_pid_" << pid \ + << "." << DEBUG_LOG_FILE_SUFFIX; + std::fstream log_file; + log_file.open(log_file_name.str(), std::ios::app); + while (!log_queue_.empty()) { + std::string log_str = log_queue_.front().GetEntireLogMsg(pid); + log_file << log_str; + log_queue_.pop(); + } + log_file.close(); +} + +MessageInfrastructureLog::~MessageInfrastructureLog() { + WriteDown(); +} + +void LogClear() { +#if defined(MSG_LOG_FILE_ENABLE) + GetLogInstance()->Clear(); +#endif +} + +MessageInfrastructureLogPtr log_instance; + +MessageInfrastructureLogPtr GetLogInstance() { + if (log_instance == nullptr) { + log_instance = std::make_shared(); + } + return log_instance; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/message_infrastructure_logging.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/message_infrastructure_logging.h new file mode 100644 index 000000000..85c5cd9d7 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/message_infrastructure_logging.h @@ -0,0 +1,184 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_MESSAGE_INFRASTRUCTURE_LOGGING_H_ +#define CORE_MESSAGE_INFRASTRUCTURE_LOGGING_H_ + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include +#include +#include + +#if _WIN32 +#inlcude +#include // _getcwd() +#define getpid() _getpid() +#define getcwd() _getcwd() +#else // __linux__ & __APPLE__ +#include +#include +#endif + +#define MAX_SIZE_LOG (1) +#define MAX_SIZE_LOG_TIME (64) +#define MAX_SIZE_PER_LOG_MSG (1024) + +#define NULL_STRING "" +#define DEBUG_LOG_FILE_SUFFIX "log" +#define LOG_GET_TIME_FAIL "Get log time failed." +#define DEBUG_LOG_MODULE "lava_message_infrastructure" +#define LOG_MSG_SUBSTITUTION "This message was displayed due to " \ + "the failure of the malloc of this log message!" + +// the following macros indicate if the specific log message need to be printed +// except the ERROR log message, ERROR log message will be printed whatever the +// macro value is +#define LOG_MP (1) // log for multiprocessing +#define LOG_ACTOR (1) +#define LOG_LAYER (1) +#define LOG_SMMP (1) // log for shmemport +#define LOG_SKP (1) // log for socketport +#define LOG_DDS (1) // lof for DDS Channel +#define LOG_UTTEST (1) + +#if defined(MSG_LOG_LEVEL) +#elif defined(MSG_LOG_LEVEL_ALL) + #define MSG_LOG_LEVEL (LOG_MASK_DBUG) +#elif defined(MSG_LOG_LEVEL_WARN) + #define MSG_LOG_LEVEL (LOG_MASK_WARN) +#elif defined(MSG_LOG_LEVEL_DUMP) + #define MSG_LOG_LEVEL (LOG_MASK_DUMP) +#elif defined(MSG_LOG_LEVEL_INFO) + #define MSG_LOG_LEVEL (LOG_MASK_INFO) +#else + #define MSG_LOG_LEVEL (LOG_MASK_ERRO) // default +#endif + +#if defined(MSG_LOG_FILE_ENABLE) +#define DEBUG_LOG_PRINT(_level, _fmt, ...) do { \ + std::vector log_data(MAX_SIZE_PER_LOG_MSG); \ + log_data.emplace_back('\0'); \ + int length = std::snprintf(log_data.data(), \ + MAX_SIZE_PER_LOG_MSG, _fmt, ## __VA_ARGS__); \ + log_data.resize(length); \ + log_data.emplace_back('\0'); \ + if (length < 0) { \ + GetLogInstance()->LogWrite(LogMsg(std::string(LOG_MSG_SUBSTITUTION), \ + __FILE__, \ + __LINE__, \ + _level)); \ + } else { \ + GetLogInstance()->LogWrite(LogMsg(std::string(log_data.data()), \ + __FILE__, \ + __LINE__, \ + _level)); \ + } \ + std::printf("%s[%d] ", _level, getpid()); \ + std::printf(_fmt, ## __VA_ARGS__); \ +} while (0) +#else +#define DEBUG_LOG_PRINT(_level, _fmt, ...) do { \ + std::printf("%s[%d]", _level, getpid()); \ + std::printf(_fmt, ## __VA_ARGS__); \ +} while (0) +#endif + +#define LAVA_LOG(_module, _fmt, ...) do { \ + if ((_module) && (MSG_LOG_LEVEL <= LOG_MASK_INFO)) { \ + DEBUG_LOG_PRINT("[CPP INFO]", _fmt, ## __VA_ARGS__); \ + } \ +} while (0) + +#define LAVA_DUMP(_module, _fmt, ...) do { \ + if (_module && (MSG_LOG_LEVEL <= LOG_MASK_DUMP)) { \ + DEBUG_LOG_PRINT("[CPP DUMP]", _fmt, ## __VA_ARGS__); \ + } \ +} while (0) + +#define LAVA_DEBUG(_module, _fmt, ...) do { \ + if (_module && (MSG_LOG_LEVEL <= LOG_MASK_DBUG)) { \ + DEBUG_LOG_PRINT("[CPP DBUG]", _fmt, ## __VA_ARGS__); \ + } \ +} while (0) + +#define LAVA_LOG_WARN(_module, _fmt, ...) do { \ + if (_module && (MSG_LOG_LEVEL <= LOG_MASK_WARN)) { \ + DEBUG_LOG_PRINT("[CPP WARN]", _fmt, ## __VA_ARGS__); \ + } \ +} while (0) + +#define LAVA_LOG_ERR(_fmt, ...) do { \ + DEBUG_LOG_PRINT("[CPP ERRO]", _fmt, ## __VA_ARGS__); \ +} while (0) + +#define LAVA_LOG_FATAL(_fmt, ...) do { \ + DEBUG_LOG_PRINT("[CPP FATAL ERRO]", _fmt, ## __VA_ARGS__); \ + exit(-1); \ +} while (0) + +#define LAVA_ASSERT_INT(result, expectation) do { \ + if (int r = (result) != expectation) { \ + LAVA_LOG_ERR("Assert failed, %d get, %d except. Errno: %d\n", \ + r, 0, errno); \ + exit(-1); \ + } \ +} while (0) + +namespace message_infrastructure { + +enum LogLevel { + LOG_MASK_DBUG, + LOG_MASK_INFO, + LOG_MASK_DUMP, + LOG_MASK_WARN, + LOG_MASK_ERRO +}; + +class LogMsg{ + public: + LogMsg(const std::string &msg_data, + const char *log_file, + const int &log_line, + const char *log_level); + std::string GetEntireLogMsg(const int &pid); + + private: + std::string msg_time_; + std::string msg_data_; + std::string msg_level_; + std::string msg_file_; + int msg_line_ = 0; +}; + +class MessageInfrastructureLog { + public: + MessageInfrastructureLog(); + ~MessageInfrastructureLog(); + void LogWrite(const LogMsg &msg); + void Clear(); + void WriteDown(); + + private: + std::string log_path_; + std::mutex log_lock_; + std::queue log_queue_; +}; + +// MessageInfrastructureLog object should be handled by multiple actors. +// Use std::shared_ptr. +using MessageInfrastructureLogPtr = std::shared_ptr; + +MessageInfrastructureLogPtr GetLogInstance(); + +void LogClear(); + +} // namespace message_infrastructure + +#endif // CORE_MESSAGE_INFRASTRUCTURE_LOGGING_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/multiprocessing.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/multiprocessing.cc new file mode 100644 index 000000000..34310773f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/multiprocessing.cc @@ -0,0 +1,58 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#if defined(GRPC_CHANNEL) +#include +#endif + +namespace message_infrastructure { + +ProcessType MultiProcessing::BuildActor(AbstractActor::TargetFn target_fn) { + AbstractActor::ActorPtr actor = new PosixActor(target_fn); + ProcessType ret = actor->Create(); + actors_.push_back(actor); + return ret; +} + +void MultiProcessing::Stop() { + for (auto actor : actors_) { + actor->Control(ActorCmd::CmdStop); + } + LAVA_LOG(LOG_MP, "Send Stop cmd to Actors\n"); +} + +void MultiProcessing::Cleanup(bool block) { + if (block) { + for (auto actor : actors_) { + actor->Wait(); + } + } + GetSharedMemManagerSingleton().DeleteAllSharedMemory(); +#if defined(GRPC_CHANNEL) + GetGrpcManagerSingleton().Release(); +#endif +} + +void MultiProcessing::CheckActor() { + for (auto actor : actors_) { + LAVA_LOG(LOG_MP, "Actor info: (pid, status):(%d, %d)", + actor->GetPid(), static_cast(actor->GetStatus())); + } +} + +std::vector& MultiProcessing::GetActors() { + return actors_; +} + +MultiProcessing::~MultiProcessing() { + for (auto actor : actors_) { + delete actor; + } + GetSharedMemManagerSingleton().DeleteAllSharedMemory(); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/multiprocessing.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/multiprocessing.h new file mode 100644 index 000000000..1ae43b5fa --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/multiprocessing.h @@ -0,0 +1,31 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_MULTIPROCESSING_H_ +#define CORE_MULTIPROCESSING_H_ + +#include +#include +#include +#include + +namespace message_infrastructure { + +class MultiProcessing { + public: + ~MultiProcessing(); + void Stop(); + ProcessType BuildActor(AbstractActor::TargetFn target_fn); + void CheckActor(); + void Cleanup(bool block); + std::vector& GetActors(); + + private: + bool block_; + std::vector actors_; +}; + +} // namespace message_infrastructure + +#endif // CORE_MULTIPROCESSING_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/ports.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/ports.cc new file mode 100644 index 000000000..2e5d08e87 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/ports.cc @@ -0,0 +1,197 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include + +namespace message_infrastructure { + +CppInPort::CppInPort(const RecvPortList &recv_ports) + : AbstractPortImplementation(recv_ports) +{} + +bool CppInPort::Probe() { + return true; +} + +int CppInPortVectorDense::Recv() { + // Todo + return 0; +} + +int CppInPortVectorDense::Peek() { + // Todo + return 0; +} + +int CppInPortVectorSparse::Recv() { + // Todo + return 0; +} + +int CppInPortVectorSparse::Peek() { + // Todo + return 0; +} + +int CppInPortScalarDense::Recv() { + // Todo + return 0; +} + +int CppInPortScalarDense::Peek() { + // Todo + return 0; +} + +int CppInPortScalarSparse::Recv() { + // Todo + return 0; +} + +int CppInPortScalarSparse::Peek() { + // Todo + return 0; +} + +CppOutPort::CppOutPort(const SendPortList &send_ports) + : AbstractPortImplementation(send_ports) +{} + +int CppOutPortVectorDense::Send() { + // Todo + return 0; +} + +int CppOutPortVectorSparse::Send() { + // Todo + return 0; +} + +int CppOutPortScalarDense::Send() { + // Todo + return 0; +} + +int CppOutPortScalarSparse::Send() { + // Todo + return 0; +} + +CppRefPort::CppRefPort(const SendPortList &send_ports, + const RecvPortList &recv_ports) + : AbstractPortImplementation(send_ports, recv_ports) +{} + +int CppRefPort::Wait() { + // Todo + return 0; +} + +int CppRefPortVectorDense::Read() { + // Todo + return 0; +} + +int CppRefPortVectorDense::Write() { + // Todo + return 0; +} + +int CppRefPortVectorSparse::Read() { + // Todo + return 0; +} + +int CppRefPortVectorSparse::Write() { + // Todo + return 0; +} + +int CppRefPortScalarDense::Read() { + // Todo + return 0; +} + +int CppRefPortScalarDense::Write() { + // Todo + return 0; +} + +int CppRefPortScalarSparse::Read() { + // Todo + return 0; +} + +int CppRefPortScalarSparse::Write() { + // Todo + return 0; +} + +CppVarPort::CppVarPort(const std::string &name, + const SendPortList &send_ports, + const RecvPortList &recv_ports) + : name_(name), AbstractPortImplementation(send_ports, recv_ports) +{} + +int CppVarPortVectorDense::Service() { + // Todo + return 0; +} + +int CppVarPortVectorDense::Recv() { + // Todo + return 0; +} + +int CppVarPortVectorDense::Peek() { + // Todo + return 0; +} + +int CppVarPortVectorSparse::Service() { + // Todo + return 0; +} + +int CppVarPortVectorSparse::Recv() { + // Todo + return 0; +} + +int CppVarPortVectorSparse::Peek() { + // Todo + return 0; +} + +int CppVarPortScalarDense::Service() { + // Todo + return 0; +} + +int CppVarPortScalarDense::Recv() { + // Todo + return 0; +} + +int CppVarPortScalarDense::Peek() { + // Todo + return 0; +} + +int CppVarPortScalarSparse::Service() { + // Todo + return 0; +} + +int CppVarPortScalarSparse::Recv() { + // Todo + return 0; +} + +int CppVarPortScalarSparse::Peek() { + // Todo + return 0; +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/ports.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/ports.h new file mode 100644 index 000000000..852d11d56 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/ports.h @@ -0,0 +1,208 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_PORTS_H_ +#define CORE_PORTS_H_ + +#include + +#include +#include +#include +#include + +namespace message_infrastructure { + +class CppInPort : public AbstractPortImplementation { + public: + explicit CppInPort(const RecvPortList &recv_ports); + virtual ~CppInPort() = default; + bool Probe(); + virtual int Recv() = 0; + virtual int Peek() = 0; +}; + + +class CppInPortVectorDense final : public CppInPort { + public: + using CppInPort::CppInPort; + ~CppInPortVectorDense() override {} + int Recv() override; + int Peek() override; +}; + + +class CppInPortVectorSparse final : public CppInPort { + public: + using CppInPort::CppInPort; + ~CppInPortVectorSparse() override {} + int Recv() override; + int Peek() override; +}; + + +class CppInPortScalarDense final : public CppInPort { + public: + using CppInPort::CppInPort; + ~CppInPortScalarDense() override {} + int Recv() override; + int Peek() override; +}; + + +class CppInPortScalarSparse final : public CppInPort { + public: + using CppInPort::CppInPort; + ~CppInPortScalarSparse() override {} + int Recv() override; + int Peek() override; +}; + + +class CppOutPort : public AbstractPortImplementation { + public: + explicit CppOutPort(const SendPortList &send_ports); + virtual ~CppOutPort() = default; + virtual int Send() = 0; + void Flush() {} +}; + + +class CppOutPortVectorDense final : public CppOutPort { + public: + using CppOutPort::CppOutPort; + ~CppOutPortVectorDense() override {} + int Send() override; +}; + + +class CppOutPortVectorSparse final : public CppOutPort { + public: + using CppOutPort::CppOutPort; + ~CppOutPortVectorSparse() override {} + int Send() override; +}; + + +class CppOutPortScalarDense final : public CppOutPort { + public: + using CppOutPort::CppOutPort; + ~CppOutPortScalarDense() override {} + int Send() override; +}; + + +class CppOutPortScalarSparse final : public CppOutPort { + public: + using CppOutPort::CppOutPort; + ~CppOutPortScalarSparse() override {} + int Send() override; +}; + + +class CppRefPort : public AbstractPortImplementation { + public: + explicit CppRefPort(const SendPortList &send_ports, + const RecvPortList &recv_ports); + virtual ~CppRefPort() = default; + virtual int Read() = 0; + virtual int Write() = 0; + int Wait(); +}; + + +class CppRefPortVectorDense final : public CppRefPort { + public: + using CppRefPort::CppRefPort; + ~CppRefPortVectorDense() override {} + int Read() override; + int Write() override; +}; + + +class CppRefPortVectorSparse final : public CppRefPort { + public: + using CppRefPort::CppRefPort; + ~CppRefPortVectorSparse() override {} + int Read() override; + int Write() override; +}; + + +class CppRefPortScalarDense final : public CppRefPort { + public: + using CppRefPort::CppRefPort; + ~CppRefPortScalarDense() override {} + int Read() override; + int Write() override; +}; + + +class CppRefPortScalarSparse final : public CppRefPort { + public: + using CppRefPort::CppRefPort; + ~CppRefPortScalarSparse() override {} + int Read() override; + int Write() override; +}; + + +class CppVarPort : public AbstractPortImplementation { + public: + explicit CppVarPort(const std::string &name, + const SendPortList &send_ports, + const RecvPortList &recv_ports); + virtual ~CppVarPort() = default; + virtual int Service() = 0; + virtual int Recv() = 0; + virtual int Peek() = 0; + + private: + std::string name_; +}; + + +class CppVarPortVectorDense final : public CppVarPort { + public: + using CppVarPort::CppVarPort; + ~CppVarPortVectorDense() override {} + int Service() override; + int Recv() override; + int Peek() override; +}; + + +class CppVarPortVectorSparse final : public CppVarPort { + public: + using CppVarPort::CppVarPort; + ~CppVarPortVectorSparse() override {} + int Service() override; + int Recv() override; + int Peek() override; +}; + + +class CppVarPortScalarDense final : public CppVarPort { + public: + using CppVarPort::CppVarPort; + ~CppVarPortScalarDense() override {} + int Service() override; + int Recv() override; + int Peek() override; +}; + + +class CppVarPortScalarSparse final : public CppVarPort { + public: + using CppVarPort::CppVarPort; + ~CppVarPortScalarSparse() override {} + int Service() override; + int Recv() override; + int Peek() override; +}; + + +} // namespace message_infrastructure + +#endif // CORE_PORTS_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/utils.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/utils.h new file mode 100644 index 000000000..6a2bf7e19 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/core/utils.h @@ -0,0 +1,209 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef CORE_UTILS_H_ +#define CORE_UTILS_H_ + +#if defined(GRPC_CHANNEL) +#include +#endif +#include +#include +#include // NOLINT +#include // NOLINT + +#if defined(ENABLE_MM_PAUSE) +#include +#endif + +#define MAX_ARRAY_DIMS (5) +#define SLEEP_MS (1) + +#define LAVA_SIZEOF_CHAR (sizeof(char)) +#define LAVA_SIZEOF_UCHAR (LAVA_SIZEOF_CHAR) + +#define LAVA_SIZEOF_BOOL (LAVA_SIZEOF_UCHAR) +#define LAVA_SIZEOF_BYTE (LAVA_SIZEOF_CHAR) +#define LAVA_SIZEOF_UBYTE (LAVA_SIZEOF_UCHAR) +#define LAVA_SIZEOF_SHORT (sizeof(short)) // NOLINT +#define LAVA_SIZEOF_USHORT (sizeof(u_short)) +#define LAVA_SIZEOF_INT (sizeof(int)) +#define LAVA_SIZEOF_UINT (sizeof(uint)) +#define LAVA_SIZEOF_LONG (sizeof(long)) // NOLINT +#define LAVA_SIZEOF_ULONG (sizeof(ulong)) +#define LAVA_SIZEOF_LONGLONG (sizeof(long long)) // NOLINT +#define LAVA_SIZEOF_ULONGLONG (sizeof(ulong long)) // NOLINT +#define LAVA_SIZEOF_FLOAT (sizeof(float)) +#define LAVA_SIZEOF_DOUBLE (sizeof(double)) +#define LAVA_SIZEOF_LONGDOUBLE (sizeof(long double)) +#define LAVA_SIZEOF_NULL (0) +// the length of string is unknown +#define LAVA_SIZEOF_STRING (-1) + +#define SIZEOF(TYPE) (LAVA_SIZEOF_ARRAY[TYPE]) + +namespace message_infrastructure { + +enum class ProcessType { + ErrorProcess = 0, + ParentProcess = 1, + ChildProcess = 2 +}; + +enum class ChannelType { + SHMEMCHANNEL = 0, + RPCCHANNEL = 1, + DDSCHANNEL = 2, + SOCKETCHANNEL = 3, + TEMPCHANNEL = 4 +}; + +enum class METADATA_TYPES : int64_t{ BOOL = 0, + BYTE, UBYTE, + SHORT, USHORT, + INT, UINT, + LONG, ULONG, + LONGLONG, ULONGLONG, + FLOAT, DOUBLE, LONGDOUBLE, + // align the value of STRING to + // NPY_STRING in ndarraytypes.h + STRING = 18 +}; + +static int64_t LAVA_SIZEOF_ARRAY[static_cast(METADATA_TYPES::STRING) + 1] = + { LAVA_SIZEOF_BOOL, + LAVA_SIZEOF_BYTE, + LAVA_SIZEOF_UBYTE, + LAVA_SIZEOF_SHORT, + LAVA_SIZEOF_USHORT, + LAVA_SIZEOF_INT, + LAVA_SIZEOF_UINT, + LAVA_SIZEOF_LONG, + LAVA_SIZEOF_ULONG, + LAVA_SIZEOF_LONGLONG, + LAVA_SIZEOF_ULONGLONG, + LAVA_SIZEOF_FLOAT, + LAVA_SIZEOF_DOUBLE, + LAVA_SIZEOF_LONGDOUBLE, + LAVA_SIZEOF_NULL, + LAVA_SIZEOF_NULL, + LAVA_SIZEOF_NULL, + LAVA_SIZEOF_NULL, + LAVA_SIZEOF_STRING + }; + +struct MetaData { + int64_t nd; + int64_t type; + int64_t elsize; + int64_t total_size; + int64_t dims[MAX_ARRAY_DIMS] = {0}; + int64_t strides[MAX_ARRAY_DIMS] = {0}; + void* mdata; +}; + +// Incase Peek() and Recv() operations of ports will reuse Metadata. +// Use std::shared_ptr. +using MetaDataPtr = std::shared_ptr; +using DataPtr = std::shared_ptr; + +#if defined(GRPC_CHANNEL) +using grpcchannel::GrpcMetaData; +using GrpcMetaDataPtr = std::shared_ptr; +#endif + +inline void GetMetadata(const MetaDataPtr &metadataptr, + void *array, + const int64_t &nd, + const int64_t &dtype, + int64_t *dims) { + if (nd <= 0 || nd > MAX_ARRAY_DIMS) { + LAVA_LOG_ERR("Invalid nd: %ld\n", nd); + return; + } + for (int i = 0; i < nd ; i++) { + metadataptr->dims[i] = dims[i]; + } + int product = 1; + for (int i = 0; i < nd; i++) { + metadataptr->strides[nd - i - 1] = product; + product *= metadataptr->dims[nd - i - 1]; + } + metadataptr->total_size = product; + metadataptr->elsize = SIZEOF(dtype); + metadataptr->type = dtype; + metadataptr->mdata = array; +} + +namespace helper { + +static void Sleep() { +#if defined(ENABLE_MM_PAUSE) + _mm_pause(); +#else + std::this_thread::sleep_for(std::chrono::milliseconds(SLEEP_MS)); +#endif +} + +static void Sleep(int64_t ns) { +#if defined(ENABLE_MM_PAUSE) + _mm_pause(); +#else + std::this_thread::sleep_for(std::chrono::nanoseconds(ns)); +#endif +} + +} // namespace helper + +#if defined(DDS_CHANNEL) +// Default Parameters +// Transport +#define SHM_SEGMENT_SIZE (2 * 1024 * 1024) +#define NON_BLOCKING_SEND (false) +#define UDP_OUT_PORT (0) +#define TCP_PORT 46 +#define TCPv4_IP ("0.0.0.0") +// QOS +#define HEARTBEAT_PERIOD_SECONDS (2) +#define HEARTBEAT_PERIOD_NANOSEC (200 * 1000 * 1000) +// Topic +#define DDS_DATATYPE_NAME "ddsmetadata::msg::dds_::DDSMetaData_" +// Using default_nbytes in DDSChannel +// When user only care about the topic_name and not the nbytes +#define DEFAULT_NBYTES 0 + +enum class DDSTransportType { + DDSSHM = 0, + DDSTCPv4 = 1, + DDSTCPv6 = 2, + DDSUDPv4 = 3, + DDSUDPv6 = 4 +}; + +enum class DDSBackendType { + FASTDDSBackend = 0, + CycloneDDSBackend = 1 +}; + +enum class DDSInitErrorType { + DDSNOERR = 0, + DDSParticipantError = 1, + DDSPublisherError = 2, + DDSSubscriberError = 3, + DDSTopicError = 4, + DDSDataWriterError = 5, + DDSDataReaderError = 6, + DDSTypeParserError = 7 +}; + +#endif + +#if defined(GRPC_CHANNEL) +#define DEFAULT_GRPC_URL "0.0.0.0:" +#define DEFAULT_GRPC_PORT 8000 +#endif + +} // namespace message_infrastructure + +#endif // CORE_UTILS_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/message_infrastructure_py_wrapper.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/message_infrastructure_py_wrapper.cc new file mode 100644 index 000000000..c9c3cdac8 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/message_infrastructure_py_wrapper.cc @@ -0,0 +1,223 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace message_infrastructure { + +namespace py = pybind11; + +// Users should be allowed to copy channel objects. +// Use std::shared_ptr. +#if defined(GRPC_CHANNEL) +using GetRPCChannelProxyPtr = std::shared_ptr; +#endif + +#if defined(DDS_CHANNEL) +using GetDDSChannelProxyPtr = std::shared_ptr; +#endif + +PYBIND11_MODULE(MessageInfrastructurePywrapper, m) { + py::class_ (m, "CppMultiProcessing") + .def(py::init<>()) + .def("build_actor", &MultiProcessing::BuildActor) + .def("check_actor", &MultiProcessing::CheckActor) + .def("get_actors", &MultiProcessing::GetActors, + py::return_value_policy::reference) + .def("cleanup", &MultiProcessing::Cleanup) + .def("stop", &MultiProcessing::Stop); + py::enum_ (m, "ProcessType") + .value("ErrorProcess", ProcessType::ErrorProcess) + .value("ChildProcess", ProcessType::ChildProcess) + .value("ParentProcess", ProcessType::ParentProcess) + .export_values(); + py::enum_ (m, "ActorStatus") + .value("StatusError", ActorStatus::StatusError) + .value("StatusRunning", ActorStatus::StatusRunning) + .value("StatusStopped", ActorStatus::StatusStopped) + .value("StatusPaused", ActorStatus::StatusPaused) + .value("StatusTerminated", ActorStatus::StatusTerminated) + .export_values(); + py::enum_ (m, "ActorCmd") + .value("CmdRun", ActorCmd::CmdRun) + .value("CmdStop", ActorCmd::CmdStop) + .value("CmdPause", ActorCmd::CmdPause) + .export_values(); + py::class_ (m, "Actor") + .def("wait", &PosixActor::Wait) + .def("get_status", &PosixActor::GetStatus) + .def("set_stop_fn", &PosixActor::SetStopFn) + .def("pause", [](PosixActor &actor){ + actor.Control(ActorCmd::CmdPause); + }) + .def("start", [](PosixActor &actor){ + actor.Control(ActorCmd::CmdRun); + }) + .def("stop", [](PosixActor &actor){ + actor.Control(ActorCmd::CmdStop); + }) + .def("status_stopped", [](PosixActor &actor){ + return actor.SetStatus(ActorStatus::StatusStopped); + }) + .def("status_running", [](PosixActor &actor){ + return actor.SetStatus(ActorStatus::StatusRunning); + }) + .def("status_paused", [](PosixActor &actor){ + return actor.SetStatus(ActorStatus::StatusPaused); + }) + .def("status_terminated", [](PosixActor &actor){ + return actor.SetStatus(ActorStatus::StatusTerminated); + }) + .def("error", [](PosixActor &actor){ + return actor.SetStatus(ActorStatus::StatusError); + }); + py::enum_ (m, "ChannelType") + .value("SHMEMCHANNEL", ChannelType::SHMEMCHANNEL) + .value("RPCCHANNEL", ChannelType::RPCCHANNEL) + .value("DDSCHANNEL", ChannelType::DDSCHANNEL) + .value("SOCKETCHANNEL", ChannelType::SOCKETCHANNEL) + .export_values(); + py::class_> (m, "AbstractTransferPort") + .def(py::init<>()); + py::class_> (m, "Channel") + .def(py::init()) + .def_property_readonly("src_port", &ChannelProxy::GetSendPort, + py::return_value_policy::reference) + .def_property_readonly("dst_port", &ChannelProxy::GetRecvPort, + py::return_value_policy::reference); + py::class_> + (m, "TempChannel") + .def(py::init()) + .def(py::init<>()) + .def_property_readonly("addr_path", &TempChannelProxy::GetAddrPath) + .def_property_readonly("src_port", &TempChannelProxy::GetSendPort, + py::return_value_policy::reference) + .def_property_readonly("dst_port", &TempChannelProxy::GetRecvPort, + py::return_value_policy::reference); +#if defined(GRPC_CHANNEL) + py::class_ (m, "GetRPCChannel") + .def(py::init()) + .def(py::init()) + .def_property_readonly("src_port", &GetRPCChannelProxy::GetSendPort, + py::return_value_policy::reference) + .def_property_readonly("dst_port", &GetRPCChannelProxy::GetRecvPort, + py::return_value_policy::reference); +#endif + m.def("support_grpc_channel", [](){ +#if defined(GRPC_CHANNEL) + return true; +#else + return false; +#endif + }); + +#if defined(DDS_CHANNEL) + py::enum_ (m, "DDSTransportType") + .value("DDSSHM", DDSTransportType::DDSSHM) + .value("DDSTCPv4", DDSTransportType::DDSTCPv4) + .value("DDSTCPv6", DDSTransportType::DDSTCPv6) + .value("DDSUDPv4", DDSTransportType::DDSUDPv4) + .value("DDSUDPv6", DDSTransportType::DDSUDPv6) + .export_values(); + + py::enum_ (m, "DDSBackendType") + .value("FASTDDSBackend", DDSBackendType::FASTDDSBackend) + .value("CycloneDDSBackend", DDSBackendType::CycloneDDSBackend) + .export_values(); + + py::class_ (m, "GetDDSChannel") + .def(py::init([]( + const std::string& topic_name, + size_t size, + DDSTransportType dds_transfer_type, + DDSBackendType dds_backend) { + std::string src_name = + topic_name+"_src_" + std::to_string(std::rand()); + std::string dst_name = + topic_name+"_dst_" + std::to_string(std::rand()); + return new GetDDSChannelProxy(src_name, dst_name, + topic_name, size, DEFAULT_NBYTES, + dds_transfer_type, dds_backend); + }) + ) + .def(py::init([]( + const std::string& src_name, + const std::string& dst_name, + size_t size, + size_t nbytes, + DDSTransportType dds_transfer_type, + DDSBackendType dds_backend) { + std::string topic_name = + "dds_topic_" + std::to_string(std::rand()); + return new GetDDSChannelProxy(src_name, dst_name, + topic_name, size, nbytes, + dds_transfer_type, dds_backend); + }) + ) + .def_property_readonly("src_port", &GetDDSChannelProxy::GetSendPort, + py::return_value_policy::reference) + .def_property_readonly("dst_port", &GetDDSChannelProxy::GetRecvPort, + py::return_value_policy::reference); + +#endif + + m.def("support_fastdds_channel", [](){ +#if defined(FASTDDS_ENABLE) + return true; +#else + return false; +#endif + }); + + m.def("support_cyclonedds_channel", [](){ +#if defined(CycloneDDS_ENABLE) + return true; +#else + return false; +#endif + }); + + py::class_> (m, "SendPort") + .def(py::init<>()) + .def("get_channel_type", &SendPortProxy::GetChannelType) + .def("start", &SendPortProxy::Start) + .def("probe", &SendPortProxy::Probe) + .def("send", &SendPortProxy::Send) + .def("join", &SendPortProxy::Join) + .def_property_readonly("name", &SendPortProxy::Name) + .def_property_readonly("shape", &SendPortProxy::Shape) + .def_property_readonly("d_type", &SendPortProxy::DType) + .def_property_readonly("size", &SendPortProxy::Size); + py::class_> (m, "RecvPort") + .def(py::init<>()) + .def("get_channel_type", &RecvPortProxy::GetChannelType) + .def("start", &RecvPortProxy::Start) + .def("probe", &RecvPortProxy::Probe) + .def("recv", &RecvPortProxy::Recv) + .def("peek", &RecvPortProxy::Peek) + .def("join", &RecvPortProxy::Join) + .def_property_readonly("name", &RecvPortProxy::Name) + .def_property_readonly("shape", &RecvPortProxy::Shape) + .def_property_readonly("d_type", &RecvPortProxy::DType) + .def_property_readonly("size", &RecvPortProxy::Size); + py::class_> (m, "CPPSelector") + .def(py::init<>()) + .def("select", &Selector::Select); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.cc b/src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.cc new file mode 100644 index 000000000..a46429797 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.cc @@ -0,0 +1,228 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#define NUMPY_CORE_INCLUDE_NUMPY_NPY_1_7_DEPRECATED_API_H_ +// to solve the warning "Using deprecated NumPy API, +// disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" + +#include +#include +#include +#include + +namespace message_infrastructure { + +namespace py = pybind11; + +#if defined(GRPC_CHANNEL) +DataPtr GrpcMDataFromObject_(py::object* object) { + PyObject *obj = object->ptr(); + LAVA_LOG(LOG_LAYER, "start GrpcMDataFromObject\n"); + if (!PyArray_Check(obj)) { + LAVA_LOG_FATAL("The Object is not array tp is %s\n", Py_TYPE(obj)->tp_name); + exit(-1); + } + LAVA_LOG(LOG_LAYER, "check obj achieved\n"); + auto array = reinterpret_cast (obj); + if (!PyArray_ISWRITEABLE(array)) { + LAVA_LOG(LOG_LAYER, "The array is not writeable\n"); + } + int32_t ndim = PyArray_NDIM(array); + auto dims = PyArray_DIMS(array); + auto strides = PyArray_STRIDES(array); + void* data_ptr = PyArray_DATA(array); + auto dtype = array->descr->type_num; + auto element_size_in_bytes = PyArray_ITEMSIZE(array); + auto tsize = PyArray_SIZE(array); + // set grpcdata + GrpcMetaDataPtr grpcdata = std::make_shared(); + grpcdata->set_nd(ndim); + grpcdata->set_type(dtype); + grpcdata->set_elsize(element_size_in_bytes); + grpcdata->set_total_size(tsize); + for (int i = 0; i < ndim; i++) { + grpcdata->add_dims(dims[i]); + grpcdata->add_strides(strides[i]/element_size_in_bytes); + if (strides[i] % element_size_in_bytes != 0) { + LAVA_LOG_FATAL("Numpy array stride not a multiple of element bytes\n"); + } + } + char* data = reinterpret_cast(data_ptr); + grpcdata->set_value(data, element_size_in_bytes*tsize); + return grpcdata; +} +#endif + +DataPtr MDataFromObject_(py::object* object) { + PyObject *obj = object->ptr(); + LAVA_LOG(LOG_LAYER, "start MDataFromObject\n"); + if (!PyArray_Check(obj)) { + LAVA_LOG_FATAL("The Object is not array tp is %s\n", Py_TYPE(obj)->tp_name); + } + LAVA_LOG(LOG_LAYER, "check obj achieved\n"); + + auto array = reinterpret_cast (obj); + if (!PyArray_ISWRITEABLE(array)) { + LAVA_LOG(LOG_LAYER, "The array is not writeable\n"); + } + + // var from numpy + int32_t ndim = PyArray_NDIM(array); + auto dims = PyArray_DIMS(array); + auto strides = PyArray_STRIDES(array); + void* data_ptr = PyArray_DATA(array); + // auto dtype = PyArray_Type(array); // no work + auto dtype = array->descr->type_num; + auto element_size_in_bytes = PyArray_ITEMSIZE(array); + auto tsize = PyArray_SIZE(array); + // set metadata + MetaDataPtr metadata = std::make_shared(); + metadata->nd = ndim; + for (int i = 0; i < ndim; i++) { + metadata->dims[i] = dims[i]; + metadata->strides[i] = strides[i]/element_size_in_bytes; + if (strides[i] % element_size_in_bytes != 0) { + LAVA_LOG_ERR("Numpy array stride not a multiple of element bytes\n"); + } + } + metadata->type = dtype; + metadata->mdata = data_ptr; + metadata->elsize = element_size_in_bytes; + metadata->total_size = tsize; + return metadata; +} + +void MetaDataDump(MetaDataPtr metadata) { + int64_t *dims = metadata->dims; + int64_t *strides = metadata->strides; + LAVA_DUMP(LOG_LAYER, "MetaData Info:\n" + "(nd, type, elsize): (%ld, %ld, %ld)\n" + "total_size: %ld\n" + "dims:[%ld, %ld, %ld, %ld, %ld]\n" + "strides:[%ld, %ld, %ld, %ld, %ld]\n", + metadata->nd, + metadata->type, + metadata->elsize, + metadata->total_size, + dims[0], dims[1], dims[2], dims[3], dims[4], + strides[0], strides[1], strides[2], strides[3], strides[4]); +} + +py::object PortProxy::DType() { + return d_type_; +} + +py::tuple PortProxy::Shape() { + return shape_; +} + +ChannelType SendPortProxy::GetChannelType() { + return channel_type_; +} +void SendPortProxy::Start() { + send_port_->Start(); +} +bool SendPortProxy::Probe() { + return send_port_->Probe(); +} +void SendPortProxy::Send(py::object* object) { + DataPtr data = DataFromObject_(object); + send_port_->Send(data); +} +void SendPortProxy::Join() { + send_port_->Join(); +} +std::string SendPortProxy::Name() { + return send_port_->Name(); +} +size_t SendPortProxy::Size() { + return send_port_->Size(); +} + +ChannelType RecvPortProxy::GetChannelType() { + return channel_type_; +} +void RecvPortProxy::Start() { + recv_port_->Start(); +} +bool RecvPortProxy::Probe() { + return recv_port_->Probe(); +} +py::object RecvPortProxy::Recv() { + MetaDataPtr metadata = recv_port_->Recv(); + return MDataToObject_(metadata); +} +void RecvPortProxy::Join() { + recv_port_->Join(); +} +py::object RecvPortProxy::Peek() { + MetaDataPtr metadata = recv_port_->Peek(); + return MDataToObject_(metadata); +} +std::string RecvPortProxy::Name() { + return recv_port_->Name(); +} +size_t RecvPortProxy::Size() { + return recv_port_->Size(); +} + +int trick() { + // to solve the warning "converting to non-pointer type 'int' + // from NULL [-Wconversion-null] import_array()" + _import_array(); + return 0; +} + +const int tricky_var = trick(); + +DataPtr SendPortProxy::DataFromObject_(py::object* object) { +#if defined(GRPC_CHANNEL) + if (channel_type_== ChannelType::RPCCHANNEL) { + return GrpcMDataFromObject_(object); + } +#endif + return MDataFromObject_(object); +} + +py::object RecvPortProxy::MDataToObject_(MetaDataPtr metadata) { + if (metadata == nullptr) + return py::cast(0); + + std::vector dims(metadata->nd); + std::vector strides(metadata->nd); + + for (int i = 0; i < metadata->nd; i++) { + dims[i] = metadata->dims[i]; + strides[i] = metadata->strides[i] * metadata->elsize; + } + + PyObject *array = PyArray_New( + &PyArray_Type, + metadata->nd, + dims.data(), + metadata->type, + strides.data(), + metadata->mdata, + metadata->elsize, + NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE, + nullptr); + + if (!array) + return py::cast(0); + LAVA_DEBUG(LOG_LAYER, "Set PyObject capsule, mdata: %p\n", metadata->mdata); + PyObject *capsule = PyCapsule_New(metadata->mdata, nullptr, + [](PyObject *capsule) { + void *memory = PyCapsule_GetPointer(capsule, nullptr); + LAVA_DEBUG(LOG_LAYER, "PyObject cleaned, free memory: %p.\n", memory); + free(memory); + LAVA_DEBUG(LOG_LAYER, "memory has been released\n");}); + LAVA_ASSERT_INT(nullptr == capsule, 0); + LAVA_ASSERT_INT(PyArray_SetBaseObject( + reinterpret_cast(array), + capsule), 0); + return py::reinterpret_steal(array); +} + + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.h b/src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.h new file mode 100644 index 000000000..b5547b166 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/csrc/port_proxy.h @@ -0,0 +1,110 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef PORT_PROXY_H_ +#define PORT_PROXY_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +namespace py = pybind11; + +class PortProxy { + public: + PortProxy() {} + PortProxy(py::tuple shape, py::object d_type) : + shape_(shape), d_type_(d_type) {} + py::object DType(); + py::tuple Shape(); + private: + py::object d_type_; + py::tuple shape_; +}; + +class SendPortProxy : public PortProxy { + public: + SendPortProxy() {} + SendPortProxy(ChannelType channel_type, + AbstractSendPortPtr send_port, + py::tuple shape = py::make_tuple(), + py::object type = py::none()) : + PortProxy(shape, type), + channel_type_(channel_type), + send_port_(send_port) {} + ChannelType GetChannelType(); + void Start(); + bool Probe(); + void Send(py::object* object); + void Join(); + std::string Name(); + size_t Size(); + + private: + DataPtr DataFromObject_(py::object* object); + ChannelType channel_type_; + AbstractSendPortPtr send_port_; +}; + + +class RecvPortProxy : public PortProxy { + public: + RecvPortProxy() {} + RecvPortProxy(ChannelType channel_type, + AbstractRecvPortPtr recv_port, + py::tuple shape = py::make_tuple(), + py::object type = py::none()) : + PortProxy(shape, type), + channel_type_(channel_type), + recv_port_(recv_port) {} + + ChannelType GetChannelType(); + void Start(); + bool Probe(); + py::object Recv(); + void Join(); + py::object Peek(); + std::string Name(); + size_t Size(); + + private: + py::object MDataToObject_(MetaDataPtr metadata); + ChannelType channel_type_; + AbstractRecvPortPtr recv_port_; +}; + +// Users should be allowed to copy port objects. +// Use std::shared_ptr. +using SendPortProxyPtr = std::shared_ptr; +using RecvPortProxyPtr = std::shared_ptr; +using SendPortProxyList = std::vector; +using RecvPortProxyList = std::vector; + + +class Selector { + public: + pybind11::object Select(std::vector> *args, const int64_t sleep_ns) { + while (true) { + for (auto it = args->begin(); it != args->end(); ++it) { + if (std::get<0>(*it)->Probe()) { + return std::get<1>(*it)(); + } + } + helper::Sleep(sleep_ns); + } + } +}; + +} // namespace message_infrastructure + +#endif // PORT_PROXY_H_ diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/CMakeLists.txt b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/CMakeLists.txt new file mode 100644 index 000000000..6d0416e09 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(../../csrc) + +add_executable(cprocess + "cprocess.cc" +) + +target_link_libraries(cprocess PRIVATE + message_infrastructure +) + + +set_target_properties(cprocess + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" +) \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/README.md b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/README.md new file mode 100644 index 000000000..5a5968c89 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/README.md @@ -0,0 +1,25 @@ +# C pingpong + +## Run instructions +One needs to run the `p.py` first, start the `./cprocess` binary once seeing the prompt and hit the enter as indicated. + +Two args can be given as the socket file names. + +```bash +# p.py +$ python3 p.py c2py py2c + +# cprocess, in another terminal window +$ ./cprocess c2py py2c +``` + +## Notes on current TempChennel: +TempChannel uses socket file. + +The Recv port will bind the socket file in initialization, listen in `start()` and accept in `recv()`. After established a connection, the port closes it immediately after reading from the socket. + +The Send port will connect to the recv port in initialization and write to the socket in send(). + +Therefore, +1. The send port can only be initialized after corresponding Recv port called `start()` +2. In each round, the send port is used one-off. One needs to create a new `TempChannel()` and get the send port from it each time. (The send port will be initialized when accessing the .dst_port property at the first time) diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/cprocess.cc b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/cprocess.cc new file mode 100644 index 000000000..a8a2377db --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/cprocess.cc @@ -0,0 +1,44 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include + +using namespace message_infrastructure; // NOLINT + +int main(int argc, char *argv[]) { + char *c2py = (argc >= 2) ? argv[1] : const_cast("./c2py"); + char *py2c = (argc >= 3) ? argv[2] : const_cast("./py2c"); + + std::cout << "socket files: " << c2py << " " << py2c << "\n"; + + ChannelFactory &channel_factory = GetChannelFactory(); + + AbstractChannelPtr ch = channel_factory.GetTempChannel(py2c); + AbstractRecvPortPtr rc = ch->GetRecvPort(); + + // order matters + rc->Start(); + + for (uint _ = 0; _ < 10; ++_) { + std::cout << "receiving\n"; + MetaDataPtr recvd = rc->Recv(); + std::cout << "received from py, total size: " + << recvd->total_size + << "\n"; + + AbstractChannelPtr ch2 = channel_factory.GetTempChannel(c2py); + AbstractSendPortPtr sd = ch2->GetSendPort(); + sd->Start(); + sd->Send(recvd); + sd->Join(); + } + + rc->Join(); + + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/p.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/p.py new file mode 100644 index 000000000..519a6f288 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/c_pingpong/p.py @@ -0,0 +1,67 @@ +import os +import sys +import numpy as np + +from lava.magma.runtime.message_infrastructure. \ + MessageInfrastructurePywrapper import ( + TempChannel + ) + + +# float equal +def f_eq(a, b): + return abs(a - b) < 0.001 + + +def soc_names_from_args(): + # default file names + C2PY = "./c2py" + PY2C = "./py2c" + + socket_file_names = [C2PY, PY2C] + filename_args = sys.argv[1:3] + if len(filename_args) == 1: + socket_file_names[0] = filename_args[0] + if len(filename_args) == 2: + socket_file_names = filename_args + + return socket_file_names + + +def main(): + c2py, py2c = soc_names_from_args() + + if os.path.exists(c2py): + os.remove(c2py) + if os.path.exists(py2c): + os.remove(py2c) + + # order matters + ch2 = TempChannel(c2py) + rc = ch2.dst_port + rc.start() + + input("Start the c process, hit enter when you see *receiving*") + + for i in range(10): + # send port is one-off + ch = TempChannel(py2c) + sd = ch.src_port + sd.start() + + print("round ", i) + rands = np.array([np.random.random() * 100 for __ in range(10)]) # noqa + print("Sending array to C: ", rands) + sd.send(rands) + + rands2 = rc.recv() + print("Got array from C: ", rands2) + + print("Correctness: ", all([f_eq(x, y) for x, y in zip(rands, rands2)])) # noqa + print("========================================") + sd.join() + rc.join() + + +if __name__ == "__main__": + main() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/CMakeLists.txt b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/CMakeLists.txt new file mode 100644 index 000000000..18dd8f05e --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/CMakeLists.txt @@ -0,0 +1,41 @@ +cmake_minimum_required(VERSION 3.5) +project(ddsmetadata) + +# Default to C99 +if(NOT CMAKE_C_STANDARD) + set(CMAKE_C_STANDARD 99) +endif() + +# Default to C++14 +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +# find dependencies +find_package(ament_cmake REQUIRED) +# uncomment the following section in order to fill in +# further dependencies manually. +# find_package( REQUIRED) +find_package(rosidl_default_generators REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + "msg/DDSMetaData.msg" +) +# ament_export_dependencies(rosidl_default_runtime) + +if(BUILD_TESTING) + find_package(ament_lint_auto REQUIRED) + # the following line skips the linter which checks for copyrights + # uncomment the line when a copyright and license is not present in all source files + #set(ament_cmake_copyright_FOUND TRUE) + # the following line skips cpplint (only works in a git repo) + # uncomment the line when this package is not in a git repo + #set(ament_cmake_cpplint_FOUND TRUE) + ament_lint_auto_find_test_dependencies() +endif() + +ament_package() \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/msg/DDSMetaData.msg b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/msg/DDSMetaData.msg new file mode 100644 index 000000000..5e72e3c9b --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/msg/DDSMetaData.msg @@ -0,0 +1,7 @@ +int64 nd +int64 type +int64 elsize +int64 total_size +int64[5] dims +int64[5] strides +char[] mdata \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/package.xml b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/package.xml new file mode 100644 index 000000000..166669f9a --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/DDSMetaData/package.xml @@ -0,0 +1,24 @@ + + + + ddsmetadata + 0.0.0 + This package is for henerating ddsmetadata for msg_lib dds port talking with ROS2 node. + root + BSD-3-Clause + + ament_cmake + + ament_lint_auto + ament_lint_common + + rosidl_default_generators + + rosidl_default_runtime + + rosidl_interface_packages + + + ament_cmake + + diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/README.md b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/README.md new file mode 100644 index 000000000..58cd6bbf8 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/README.md @@ -0,0 +1,196 @@ +# DDS Talking with ROS2 Example + +This example is for testing msg_lib DDS port talking with ROS2 node. The implementation of DDS port in msg_lib does not use any ROS2 libraries. +This example is using [Fast-DDS](https://github.com/eProsima/Fast-DDS/), [CycloneDDS](https://projects.eclipse.org/projects/iot.cyclonedds) and [ROS2 Foxy](https://docs.ros.org/en/foxy/index.html). + +The example contains 2 main parts, ROS2 CPP package and Python package. + +## Installation of ROS2 +Please follow [ROS2 Foxy Installation](https://docs.ros.org/en/foxy/Installation.html) to install ROS2 first. + +The defualt middleware that ROS2 uses is [Fast-RTPS](https://fast-dds.docs.eprosima.com/en/v1.7.0/) and the default DDS implementation is [eProsima’s Fast DDS](https://github.com/eProsima/Fast-DDS/). If you want to test and valid msg_lib Fast-DDS port talking with Ros2 node, so please just use the default middleware and DDS version of ROS2. + +Or if you want to test msg_lib CycloneDDS port talking with ROS2 node, please follow the [guide](https://docs.ros.org/en/foxy/Installation/DDS-Implementations/Working-with-Eclipse-CycloneDDS.html) to install and enable CycloneDDS rmw for ROS2. + +## Build DDSMetadata Package in ROS2 Workspace +As this example need to transfer DDSMetadata type data between ROS2 node and DDS port, DDSMetadata package needs to be built in ROS2 workspace first. Please follow the [guide](https://docs.ros.org/en/foxy/Tutorials/Beginner-Client-Libraries/Custom-ROS2-Interfaces.html) to build the DDSMetadata Package. + +We have already prepared `DDSMetaData` ROS2 package for the message communication structure between ROS2 and msglib. User can build `DDSMetaData` package by the commands: +``` +# open your ros2 build folder and create the ROS2 environment +$ . ~/ros2_foxy/install/local_setup.bash + +# build the DDSMetaData ROS2 Package +$ cd /src/lava/magma/runtime/_c_message_infrastructure/examples/ros2 +$ colcon build --packages-select ddsmetadata +``` + +## Build ROS2 Example Package in ROS2 Workspace +This example also provides the ROS2 package to communicate with DDS port in `ros_talk_with_dds_cpp` and `ros_talk_with_dds_py` folder. You can follow the commands here to build these packages +``` +# import the DDSMetaData Package +$ . install/local_setup.bash + +# build ros_talk_with_dds_cpp or ros_talk_with_dds_py Package +$ colcon build --packages-select ros_talk_with_dds_cpp +# or +$ colcon build --packages-select ros_talk_with_dds_py +``` + +For the detail please follow the [guide](https://docs.ros.org/en/foxy/Tutorials/Beginner-Client-Libraries/Writing-A-Simple-Cpp-Publisher-And-Subscriber.html) to build the cpp code in the folder as a ROS2 package. And for the python code, please follow this [guide](https://docs.ros.org/en/foxy/Tutorials/Beginner-Client-Libraries/Writing-A-Simple-Py-Publisher-And-Subscriber.html). + +## Build the DDS Example code +Please follow the `README.md` in path `src/lava/magma/runtime/_c_message_infrastructure/` to build the message infrastrucure library and initialize the environment. + +Note : Please change cmake args with the options to choose to build the project with FASTDDS or CycloneDDS. +``` +$ export CMAKE_ARGS="-DDDS_CHANNEL=ON -DFASTDDS_ENABLE=ON -DCMAKE_BUILD_TYPE=Debug" +``` +or +``` +$ export CMAKE_ARGS="-DDDS_CHANNEL=ON -DCycloneDDS=ON -DCMAKE_BUILD_TYPE=Debug" +``` +Then run +``` +$ cd +$ poetry install +``` +The DDS example binary will be built in /build/test. (`test_fastdds_from_ros` and `test_fastdds_to_ros`) + +## Running the Example withs FASTDDS +Please open 2 terminals. One is for running ROS2 code and the other is for running DDS port. +### 1st Terminal to Run ROS2 Node +1. Navigate to your ROS2 workspace folder. + ``` + $ cd /src/lava/magma/runtime/_c_message_infrastructure/examples/ros2 + ``` +2. Intialize ROS2 environment. + ``` + $ source /opt/ros/foxy/setup.bash + $ . install/local_setup.bash + ``` +3. Enable Qos configuration in `profile.xml`. [Configuring Fast DDS in ROS 2](https://fast-dds.docs.eprosima.com/en/latest/fastdds/ros2/ros2_configure.html) could be a reference for understanding the QOS configuration for ROS2. + + Run the commands, + ``` + $ export FASTRTPS_DEFAULT_PROFILES_FILE=/examples/ros2/profile.xml + + $ export RMW_FASTRTPS_USE_QOS_FROM_XML=1 + ``` + +4. Users could choose to run CPP or Python ROS2 package. + + (1) Run the `publisher`/`subscriber` node of `ros_talk_with_dds_cpp` ROS2 package. + ``` + $ ros2 run ros_talk_with_dds_cpp ros_pub + ``` + or + ``` + $ ros2 run ros_talk_with_dds_cpp ros_sub + ``` + (2) Run the `publisher`/`subscriber` node of `ros_talk_with_dds_py` ROS2 package. + ``` + $ ros2 run ros_talk_with_dds_py ros_pub + ``` + or + ``` + $ ros2 run ros_talk_with_dds_py ros_sub + ``` +### 2nd Terminal to Run FASTDDS Port +#### Python Example +Users could use Python test files to valid the function of this example. +1. `test_fastdds_to_ros.py` is corresponding to `subscriber` nodes of `ros_talk_with_dds_cpp` and `ros_talk_with_dds_py` ROS2 package. Please run the command: + ``` + $ python test_fastdds_to_ros.py + ``` +2. `test_fastdds_from_ros.py` is corresponding to `publisher` ROS2 node. Please run the command: + ``` + $ python test_fastdds_from_ros.py + ``` +#### CPP Example +To enbale CPP example, when users are processing the steps in 'Build the DDS Example code' part, the option for build the CPP test example need to be set on, just as, +``` +$ cmake .. -DDDS_CHANNEL=ON -DFASTDDS_ENABLE=ON -DCMAKE_BUILD_TYPE=Debug +``` +Then the CPP example test binaries will be generated in the folder, +`/build/test/` +and the names for the tests are `test_fastdds_to_ros` and `test_fastdds_from_ros`. +1. Navigate to the msg_lib build folder, + ``` + $ cd /build/test/ + ``` +2. `test_fastdds_to_ros` is corresponding to `subscriber` node of ROS2 package. Please run the command: + ``` + $ ./test_fastdds_to_ros + ``` +3. `test_fastdds_from_ros` is corresponding to `subscriber` node of ROS2 package. Please run the command: + ``` + $ ./test_fastdds_from_ros + ``` + +## Running the Example withs CycloneDDS +Please open 2 terminals. One is for running ROS2 code and the other is for running DDS port. +### 1st Terminal to Run ROS2 Node +1. Navigate to your ROS2 workspace folder. + ``` + $ cd /src/lava/magma/runtime/_c_message_infrastructure/examples/ros2 + ``` +2. Intialize ROS2 environment. + ``` + $ source /opt/ros/foxy/setup.bash + $ . install/local_setup.bash + ``` +3. Enable Cyclone middleware of ROS2 to make ROS2 communication through CycloneDDS. Make sure you have successfully install Cyclone middleware for ROS2. Then run the command in the terminal, + ``` + $ export RMW_IMPLEMENTATION=rmw_cyclonedds_cpp + ``` + +4. Users could choose to run CPP or Python ROS2 package. + + (1) Run the `publisher`/`subscriber` node of `ros_talk_with_dds_cpp` ROS2 package. + ``` + $ ros2 run ros_talk_with_dds_cpp ros_pub + ``` + or + ``` + $ ros2 run ros_talk_with_dds_cpp ros_sub + ``` + (2) Run the `publisher`/`subscriber` node of `ros_talk_with_dds_py` ROS2 package. + ``` + $ ros2 run ros_talk_with_dds_py ros_pub + ``` + or + ``` + $ ros2 run ros_talk_with_dds_py ros_sub + ``` +### 2nd Terminal to Run CycloneDDS Port +#### Python Example +Users could use Python test files to valid the function of this example. +1. `test_cyclonedds_to_ros.py` is corresponding to `subscriber` nodes of `ros_talk_with_dds_cpp` and `ros_talk_with_dds_py` ROS2 package. Please run the command: + ``` + $ python test_cyclonedds_to_ros.py + ``` +2. `test_cyclonedds_from_ros.py` is corresponding to `publisher` ROS2 node. Please run the command: + ``` + $ python test_cyclonedds_from_ros.py + ``` +#### CPP Example +To enbale CPP example, when users are processing the steps in 'Build the DDS Example code' part, the option for build the CPP test example need to be set on, just as, +``` +$ cmake .. -DDDS_CHANNEL=ON -DCycloneDDS=ON -DCMAKE_BUILD_TYPE=Debug +``` +Then the CPP example test binaries will be generated in the folder, +`/build/test/` +and the names for the tests are `test_cyclonedds_to_ros` and `test_cyclonedds_from_ros`. +1. Navigate to the msg_lib build folder, + ``` + $ cd /build/test/ + ``` +2. `test_cyclonedds_to_ros` is corresponding to `subscriber` node of ROS2 package. Please run the command: + ``` + $ ./test_cyclonedds_to_ros + ``` +3. `test_cyclonedds_from_ros` is corresponding to `subscriber` node of ROS2 package. Please run the command: + ``` + $ ./test_cyclonedds_from_ros + ``` diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/profile.xml b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/profile.xml new file mode 100644 index 000000000..881419084 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/profile.xml @@ -0,0 +1,62 @@ + + + + + TransportUDPv4 + UDPv4 + false + 0 + + + + TransportSHM + SHM + 2097152 + + + + + + profile_for_ros2_context + + TransportUDPv4 + + false + + + + + + + KEEP_ALL + 100 + + + + + ASYNCHRONOUS + + + RELIABLE + + + PREALLOCATED_WITH_REALLOC + + + + + KEEP_ALL + 100 + + + + + TRANSIENT_LOCAL + + + RELIABLE + + + PREALLOCATED_WITH_REALLOC + + \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/CMakeLists.txt b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/CMakeLists.txt new file mode 100644 index 000000000..dd7e17e43 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/CMakeLists.txt @@ -0,0 +1,30 @@ +cmake_minimum_required(VERSION 3.5) +project(ros_talk_with_dds_cpp) + +# Default to C++14 +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 14) +endif() + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +find_package(ament_cmake REQUIRED) +find_package(rclcpp REQUIRED) +find_package(ddsmetadata REQUIRED) + +include_directories(include) + +add_executable(ros_pub src/publisher.cc) +ament_target_dependencies(ros_pub rclcpp ddsmetadata) + +add_executable(ros_sub src/subscriber.cc) +ament_target_dependencies(ros_sub rclcpp ddsmetadata) + +install(TARGETS + ros_pub + ros_sub + DESTINATION lib/${PROJECT_NAME}) + +ament_package() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/package.xml b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/package.xml new file mode 100644 index 000000000..9ef9120b0 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/package.xml @@ -0,0 +1,18 @@ + + + + ros_talk_with_dds_cpp + 0.0.0 + This package is for testing msg_lib dds port talking with ROS2 CPP node. + root + BSD-3-Clause + + ament_cmake + + ament_lint_auto + ament_lint_common + + ament_cmake + rclcpp + + diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src/publisher.cc b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src/publisher.cc new file mode 100644 index 000000000..ef57287c4 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src/publisher.cc @@ -0,0 +1,80 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include // NOLINT +#include + +#include "rclcpp/rclcpp.hpp" +#include "ddsmetadata/msg/dds_meta_data.hpp" + +using namespace std::chrono_literals; // NOLINT + +#define MAX_ARRAY_DIMS (5) + +struct MetaData { + int64_t nd; + int64_t type; + int64_t elsize; + int64_t total_size; + int64_t dims[MAX_ARRAY_DIMS] = {0}; + int64_t strides[MAX_ARRAY_DIMS] = {0}; + void* mdata; +}; + +using MetaDataPtr = std::shared_ptr; + +class MinimalPublisher : public rclcpp::Node { + public: + MinimalPublisher() + : Node("minimal_publisher"), count_(0), + metadata(std::make_shared()) { + metadata->nd = 1; + metadata->type = 7; + metadata->elsize = 8; + metadata->total_size = 1; + metadata->dims[0] = 1; + metadata->strides[0] = 1; + metadata->mdata = reinterpret_cast (malloc(sizeof(int64_t))); + *reinterpret_cast(metadata->mdata) = 0; + + publisher_ = rclcpp::Node::create_publisher( + "dds_topic", + rclcpp::SystemDefaultsQoS()); + timer_ = rclcpp::Node::create_wall_timer( + 500ms, std::bind(&MinimalPublisher::timer_callback, this)); + } + + private: + void timer_callback() { + auto message = ddsmetadata::msg::DDSMetaData(); + message.nd = metadata->nd; + message.type = metadata->type; + message.elsize = metadata->elsize; + message.total_size = metadata->total_size; + for (int i = 0; i < MAX_ARRAY_DIMS; i++) { + message.dims[i] = metadata->dims[i]; + message.strides[i] = metadata->strides[i]; + } + *reinterpret_cast(metadata->mdata) = count_; + message.mdata = std::vector( + reinterpret_cast(metadata->mdata), + reinterpret_cast(metadata->mdata) + + metadata->elsize * metadata->total_size); + + RCLCPP_INFO(rclcpp::Node::get_logger(), "ROS2 publishing: '%d'", count_++); + publisher_->publish(message); + } + rclcpp::TimerBase::SharedPtr timer_; + rclcpp::Publisher::SharedPtr publisher_; + size_t count_; + MetaDataPtr metadata; + rmw_publisher_t publisher; +}; + +int main(int argc, char * argv[]) { + rclcpp::init(argc, argv); + rclcpp::spin(std::make_shared()); + rclcpp::shutdown(); + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src/subscriber.cc b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src/subscriber.cc new file mode 100644 index 000000000..f29d58b19 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_cpp/src/subscriber.cc @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include + +#include "rclcpp/rclcpp.hpp" +#include "ddsmetadata/msg/dds_meta_data.hpp" +using std::placeholders::_1; + +class MinimalSubscriber : public rclcpp::Node { + public: + MinimalSubscriber() + : Node("minimal_subscriber") { + subscription_ = rclcpp::Node::create_subscription( + "dds_topic", + rclcpp::SystemDefaultsQoS(), + std::bind(&MinimalSubscriber::topic_callback, + this, _1)); + } + + private: + void topic_callback(const ddsmetadata::msg::DDSMetaData::SharedPtr metadata) + const { + unsigned char* ptr = reinterpret_cast + (malloc(sizeof(int64_t))); + for (int i = 0; i < 8; i++) + *(ptr + i) = metadata->mdata[i]; + RCLCPP_INFO(rclcpp::Node::get_logger(), + "ROS2 heard: '%ld'", + *reinterpret_cast(ptr)); + } + rclcpp::Subscription::SharedPtr subscription_; +}; + +int main(int argc, char * argv[]) { + rclcpp::init(argc, argv); + rclcpp::spin(std::make_shared()); + rclcpp::shutdown(); + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/package.xml b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/package.xml new file mode 100644 index 000000000..4e1edc9c1 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/package.xml @@ -0,0 +1,21 @@ + + + + ros_talk_with_dds_py + 0.0.0 + This package is for testing msg_lib dds port talking with ROS2 Python node. + root + BSD-3-Clause + + rclpy + ddsmetadata + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/resource/ros_talk_with_dds_py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/resource/ros_talk_with_dds_py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/__init__.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/publisher.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/publisher.py new file mode 100644 index 000000000..85ed1a60f --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/publisher.py @@ -0,0 +1,47 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +import numpy as np +import rclpy +from rclpy.node import Node +from rclpy.qos import qos_profile_system_default + +from ddsmetadata.msg import DDSMetaData +from .utils.np_mdata_trans import nparray_to_metadata + + +class MinimalPublisher(Node): + + def __init__(self): + super().__init__('minimal_publisher') + self.publisher_ = self.create_publisher(DDSMetaData, + 'dds_topic', + qos_profile_system_default) + timer_period = 0.5 # seconds + self.timer = self.create_timer(timer_period, self.timer_callback) + self.i = 0 + + def timer_callback(self): + np_arr = np.array(([self.i, 2, 3]), np.int64) + msg = nparray_to_metadata(np_arr) + self.publisher_.publish(msg) + print("Publishing : ", np_arr) + self.i += 1 + + +def main(args=None): + rclpy.init(args=args) + + minimal_publisher = MinimalPublisher() + + rclpy.spin(minimal_publisher) + + # Destroy the node explicitly + # (optional - otherwise it will be done automatically + # when the garbage collector destroys the node object) + minimal_publisher.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/subscriber.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/subscriber.py new file mode 100644 index 000000000..ae7517872 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/subscriber.py @@ -0,0 +1,44 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import rclpy +from rclpy.node import Node +from rclpy.qos import qos_profile_system_default + +from ddsmetadata.msg import DDSMetaData +from .utils.np_mdata_trans import metadata_to_nparray + + +class MinimalSubscriber(Node): + + def __init__(self): + super().__init__('minimal_subscriber') + self.subscription = self.create_subscription( + DDSMetaData, + 'dds_topic', + self.listener_callback, + qos_profile_system_default) + # pylint: disable=W0104 + self.subscription # prevent unused variable warning + + def listener_callback(self, msg): + print("Heard : ", metadata_to_nparray(msg)) + + +def main(args=None): + rclpy.init(args=args) + + minimal_subscriber = MinimalSubscriber() + + rclpy.spin(minimal_subscriber) + + # Destroy the node explicitly + # (optional - otherwise it will be done automatically + # when the garbage collector destroys the node object) + minimal_subscriber.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/utils/__init__.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/utils/np_mdata_trans.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/utils/np_mdata_trans.py new file mode 100644 index 000000000..2a2ada521 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/ros_talk_with_dds_py/utils/np_mdata_trans.py @@ -0,0 +1,60 @@ +import numpy as np +from ddsmetadata.msg import DDSMetaData + + +DTYPE_LIST = [np.dtype(i) for i in [np.bool_, + np.byte, + np.ubyte, + np.short, + np.ushort, + np.intc, + np.uintc, + np.int_, + np.uint, + np.longlong, + np.ulonglong, + np.single, + np.double, + np.longdouble, + np.cfloat, + np.cdouble, + np.clongdouble, + np.object_, + np.string_, + np.unicode_, + np.void]] + + +def metadata_to_nparray(metadata): + ndim = metadata.nd + shape = metadata.dims + dims = list() + for i in range(ndim): + dims.append(shape[i]) + dtype = np.dtype(DTYPE_LIST[metadata.type]) + np_bytes_list = metadata.mdata + np_bytes_array = bytearray(np_bytes_list) + np_array = np.frombuffer(np_bytes_array, dtype) + np_array = np_array.reshape(tuple(dims)) + return np_array + + +def nparray_to_metadata(np_array): + metadata = DDSMetaData() + metadata.nd = np_array.ndim + metadata.type = np_array.dtype.num + metadata.elsize = np_array.itemsize + metadata.total_size = np_array.size + shape_list = list(np_array.shape) + for _ in range(5 - len(shape_list)): + shape_list.append(0) + metadata.dims = shape_list + strides_list = list(np_array.strides) + for _ in range(5 - len(strides_list)): + strides_list.append(0) + metadata.strides = strides_list + np_bytes = np_array.tobytes() + np_bytes_array = np.frombuffer(np_bytes, np.byte) + np_bytes_list = np_bytes_array.tolist() + metadata.mdata = np_bytes_list + return metadata diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/setup.cfg b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/setup.cfg new file mode 100644 index 000000000..58ff572b7 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/ros_talk_with_dds_py +[install] +install_scripts=$base/lib/ros_talk_with_dds_py diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/setup.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/setup.py new file mode 100644 index 000000000..002c962d2 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/ros_talk_with_dds_py/setup.py @@ -0,0 +1,28 @@ +from setuptools import setup + +package_name = 'ros_talk_with_dds_py' +utils = "ros_talk_with_dds_py/utils" + +setup( + name=package_name, + version='0.0.0', + packages=[package_name, utils], + data_files=[ + ('share/ament_index/resource_index/packages', + ['resource/' + package_name]), + ('share/' + package_name, ['package.xml']), + ], + install_requires=['setuptools'], + zip_safe=True, + maintainer='root', + maintainer_email='he.xu@intel.com', + description='For testing msg_lib dds port talking with ROS2 Python node', + license='BSD-3-Clause', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + 'ros_pub = ros_talk_with_dds_py.publisher:main', + 'ros_sub = ros_talk_with_dds_py.subscriber:main', + ], + }, +) diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_cyclonedds_from_ros.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_cyclonedds_from_ros.py new file mode 100644 index 000000000..f31294dd4 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_cyclonedds_from_ros.py @@ -0,0 +1,31 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +from lava.magma.runtime.message_infrastructure import ( + ChannelQueueSize, + GetDDSChannel, + DDSTransportType, + DDSBackendType +) + + +def test_ddschannel(): + name = 'rt/dds_topic' + + dds_channel = GetDDSChannel( + name, + DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend, + ChannelQueueSize) + + recv_port = dds_channel.dst_port + recv_port.start() + for _ in range(100): + res = recv_port.recv() + print(res) + recv_port.join() + + +if __name__ == "__main__": + test_ddschannel() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_cyclonedds_to_ros.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_cyclonedds_to_ros.py new file mode 100644 index 000000000..08910420c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_cyclonedds_to_ros.py @@ -0,0 +1,38 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import time + +from lava.magma.runtime.message_infrastructure import ( + ChannelQueueSize, + GetDDSChannel, + DDSTransportType, + DDSBackendType +) + + +def prepare_data(): + return np.random.random_sample((2, 4)) + + +def test_ddschannel(): + name = 'rt/dds_topic' + + dds_channel = GetDDSChannel( + name, + DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend, + ChannelQueueSize) + + send_port = dds_channel.src_port + send_port.start() + for _ in range(100): + send_port.send(prepare_data()) + time.sleep(0.1) + send_port.join() + + +if __name__ == "__main__": + test_ddschannel() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_fastdds_from_ros.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_fastdds_from_ros.py new file mode 100644 index 000000000..f6951b345 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_fastdds_from_ros.py @@ -0,0 +1,32 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + + +from lava.magma.runtime.message_infrastructure import ( + ChannelQueueSize, + GetDDSChannel, + DDSTransportType, + DDSBackendType +) + + +def test_ddschannel(): + name = 'rt/dds_topic' + + dds_channel = GetDDSChannel( + name, + DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend, + ChannelQueueSize) + + recv_port = dds_channel.dst_port + recv_port.start() + for _ in range(100): + res = recv_port.recv() + print(res) + recv_port.join() + + +if __name__ == "__main__": + test_ddschannel() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_fastdds_to_ros.py b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_fastdds_to_ros.py new file mode 100644 index 000000000..824309556 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/examples/ros2/test_fastdds_to_ros.py @@ -0,0 +1,38 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import time + +from lava.magma.runtime.message_infrastructure import ( + ChannelQueueSize, + GetDDSChannel, + DDSTransportType, + DDSBackendType +) + + +def prepare_data(): + return np.random.random_sample((2, 4)) + + +def test_ddschannel(): + name = 'rt/dds_topic' + + dds_channel = GetDDSChannel( + name, + DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend, + ChannelQueueSize) + + send_port = dds_channel.src_port + send_port.start() + for _ in range(100): + send_port.send(prepare_data()) + time.sleep(0.1) + send_port.join() + + +if __name__ == "__main__": + test_ddschannel() diff --git a/src/lava/magma/runtime/_c_message_infrastructure/setenv.sh b/src/lava/magma/runtime/_c_message_infrastructure/setenv.sh new file mode 100644 index 000000000..0418ba03d --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/setenv.sh @@ -0,0 +1,10 @@ +#!/bin/bash +SCRIPTPATH=$(cd $(dirname -- "$BASH_SOURCE") && pwd) +export MSG_LOG_PATH="${SCRIPTPATH}/log" +if [ -d "$MSG_LOG_PATH" ]; then + if [ -n $(find "$MSG_LOG_PATH" -maxdepth 1 -name 'lava_message_infrastructure_pid_*.log') ]; then + rm "$MSG_LOG_PATH/lava_message_infrastructure_pid_*.log" + fi +else + mkdir -p "$MSG_LOG_PATH" +fi diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/CMakeLists.txt b/src/lava/magma/runtime/_c_message_infrastructure/test/CMakeLists.txt new file mode 100644 index 000000000..278ebe224 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/CMakeLists.txt @@ -0,0 +1,66 @@ +include(FetchContent) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.12.1 +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +include_directories(../csrc) + +add_executable( + test_messaging_infrastructure + test_multiprocessing.cc + test_channel.cc + test_shm_delivery.cc + test_socket_delivery.cc + $<$:test_grpc_delivery.cc> + $<$:test_ddschannel.cc> + ) + + +target_link_libraries( + test_messaging_infrastructure + GTest::gtest_main + message_infrastructure + rt) + +if(FASTDDS_ENABLE) + add_executable( + test_fastdds_to_ros + test_fastdds_to_ros.cc) + + add_executable( + test_fastdds_from_ros + test_fastdds_from_ros.cc) + + target_link_libraries( + test_fastdds_to_ros + message_infrastructure) + + target_link_libraries( + test_fastdds_from_ros + message_infrastructure) + +elseif(CycloneDDS_ENABLE) + add_executable( + test_cyclonedds_to_ros + test_cyclonedds_to_ros.cc) + + add_executable( + test_cyclonedds_from_ros + test_cyclonedds_from_ros.cc) + + target_link_libraries( + test_cyclonedds_to_ros + message_infrastructure) + + target_link_libraries( + test_cyclonedds_from_ros + message_infrastructure) +endif() + +include(GoogleTest) +gtest_discover_tests(test_messaging_infrastructure PROPERTIES TIMEOUT 10) \ No newline at end of file diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_channel.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_channel.cc new file mode 100644 index 000000000..cc98fdcb9 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_channel.cc @@ -0,0 +1,120 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +class Builder { + public: + void Build() {} +}; + +MetaDataPtr ExpectData() { + auto metadata = std::make_shared(); + int32_t data[5] = {1, 3, 5, 7, 9}; + int32_t *data_ptr = data; + metadata->mdata = reinterpret_cast(data); + // metadata->mdata = (void*)data; + return metadata; +} + +void SendProc(AbstractSendPortPtr send_port, + MetaDataPtr data) { + AbstractActor::StopFn stop_fn; + std::cout << "Here I am" << std::endl; + + // Sends data + send_port->Start(); + send_port->Send(data); + send_port->Join(); + std::cout << "Status STOPPED" << std::endl; +} + +void RecvProc(AbstractRecvPortPtr recv_port) { + std::cout << "Here I am (RECV)" << std::endl; + + // Returns received data + recv_port->Start(); + auto recv_data = recv_port->Recv(); + recv_port->Join(); + std::cout << "Status STOPPED" << std::endl; +} + +TEST(TestSharedMemory, SharedMemSendReceive) { + // Creates a pair of send and receive ports + // Expects that data sent is the same as data received + // Create Shared Memory Channel + int size = 1; + int nbytes = sizeof(int); + std::string name = "test_shmem_channel"; + std::string src_name = "Source1"; + std::string dst_name = "Dest1"; + auto shmem_channel = ShmemChannel( + src_name, + dst_name, + size, + nbytes); + + AbstractSendPortPtr send_port = shmem_channel.GetSendPort(); + AbstractRecvPortPtr recv_port = shmem_channel.GetRecvPort(); + + MultiProcessing mp; + Builder *builder = new Builder(); + + AbstractActor::TargetFn send_target_fn; + AbstractActor::TargetFn recv_target_fn; + + auto data = ExpectData(); + auto send_bound_fn = std::bind(&SendProc, + send_port, + data); + send_target_fn = send_bound_fn; + + auto recv_bound_fn = std::bind(&RecvProc, + recv_port); + recv_target_fn = recv_bound_fn; + + mp.BuildActor(send_target_fn); + mp.BuildActor(recv_target_fn); + + sleep(2); + + // Stop any currently running actors + mp.Stop(); + mp.Cleanup(true); +} + +TEST(TestSharedMemory, SharedMemSingleProcess) { + // metadata generate + MetaDataPtr metadata = std::make_shared(); + int64_t dims[] = {10000, 0, 0, 0, 0}; + int64_t nd = 1; + int64_t* array_ = reinterpret_cast + (malloc(sizeof(int64_t) * dims[0])); + memset(array_, 0, sizeof(int64_t) * dims[0]); + std::fill(array_, array_ + 10, 1); + GetMetadata(metadata, array_, nd, + static_cast(METADATA_TYPES::LONG), dims); + AbstractChannelPtr shmchannel = GetChannelFactory().GetChannel( + ChannelType::SHMEMCHANNEL, + 1, + sizeof(int64_t)*10000, + "send", + "recv"); + + auto SendPort = shmchannel->GetSendPort(); + auto RecvPort = shmchannel->GetRecvPort(); + SendPort->Send(metadata); + MetaDataPtr received_data = RecvPort->Recv(); + EXPECT_EQ(10000, metadata->total_size); + EXPECT_EQ(1, *reinterpret_cast(metadata->mdata)); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_cyclonedds_from_ros.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_cyclonedds_from_ros.cc new file mode 100644 index 000000000..66cd0cdaa --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_cyclonedds_from_ros.cc @@ -0,0 +1,32 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +using namespace message_infrastructure; // NOLINT + +#define LOOP_NUM 100 + +int main() { + auto dds_channel = GetChannelFactory() + .GetDDSChannel("test_channel_src", + "test_channel_dst", + "rt/dds_topic", + 10, + DEFAULT_NBYTES, + DDSTransportType::DDSUDPv4, + DDSBackendType::CycloneDDSBackend); + auto dds_recv = dds_channel->GetRecvPort(); + int loop = LOOP_NUM; + + dds_recv->Start(); + while (loop--) { + MetaDataPtr res = dds_recv->Recv(); + printf("DDS recv : '%d'\n", *reinterpret_cast(res->mdata)); + } + dds_recv->Join(); + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_cyclonedds_to_ros.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_cyclonedds_to_ros.cc new file mode 100644 index 000000000..783beba26 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_cyclonedds_to_ros.cc @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +using namespace message_infrastructure; // NOLINT + +#define LOOP_NUM 100 + +int main() { + auto dds_channel = GetChannelFactory() + .GetDDSChannel("test_cyclonedds_src", + "test_cyclonedds_dst", + "rt/dds_topic", + 10, + sizeof(char), + DDSTransportType::DDSUDPv4, + DDSBackendType::CycloneDDSBackend); + auto dds_send = dds_channel->GetSendPort(); + int loop = LOOP_NUM; + + dds_send->Start(); + MetaDataPtr metadata = std::make_shared(); + metadata->nd = 1; + metadata->type = 7; + metadata->elsize = 1; + metadata->total_size = 1; + metadata->dims[0] = 1; + metadata->strides[0] = 1; + metadata->mdata = reinterpret_cast(malloc(sizeof(char))); + while (loop--) { + *reinterpret_cast(metadata->mdata) = loop % 255; + dds_send->Send(metadata); + printf("DDS send : '%d'\n", loop); + sleep(1); + } + dds_send->Join(); + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_ddschannel.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_ddschannel.cc new file mode 100644 index 000000000..ed0f4a14c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_ddschannel.cc @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +const size_t DATA_LENGTH = 10000; +const uint32_t loop_number = 1000; +const size_t DEPTH = 32; + +void dds_stop_fn() { + // exit(0); +} + +void dds_target_fn_a1_bound(int loop, + AbstractChannelPtr mp_to_a1, + AbstractChannelPtr a1_to_mp, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto from_mp = mp_to_a1->GetRecvPort(); + from_mp->Start(); + auto to_mp = a1_to_mp->GetSendPort(); + to_mp->Start(); + auto to_a2 = a1_to_a2->GetSendPort(); + to_a2->Start(); + auto from_a2 = a2_to_a1->GetRecvPort(); + from_a2->Start(); + while (loop--) { + MetaDataPtr data = from_mp->Recv(); + (*reinterpret_cast(data->mdata))++; + to_a2->Send(data); + free(data->mdata); + data = from_a2->Recv(); + (*reinterpret_cast(data->mdata))++; + to_mp->Send(data); + free(data->mdata); + } + from_mp->Join(); + from_a2->Join(); + to_a2->Join(); + to_mp->Join(); +} + +void dds_target_fn_a2_bound(int loop, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto from_a1 = a1_to_a2->GetRecvPort(); + from_a1->Start(); + auto to_a1 = a2_to_a1->GetSendPort(); + to_a1->Start(); + while (loop--) { + MetaDataPtr data = from_a1->Recv(); + (*reinterpret_cast(data->mdata))++; + to_a1->Send(data); + free(data->mdata); + } + from_a1->Join(); + to_a1->Join(); +} + +void dds_protocol(std::string topic_name, + DDSTransportType transfer_type, + DDSBackendType dds_backend) { + MultiProcessing mp; + int loop = loop_number; + AbstractChannelPtr mp_to_a1 = GetChannelFactory() + .GetDDSChannel(topic_name + "mp_to_a1", + topic_name + "mp_to_a1", + "mp_to_a1_topic", + DEPTH, + sizeof(int64_t)*10000, + transfer_type, + dds_backend); + AbstractChannelPtr a1_to_mp = GetChannelFactory() + .GetDDSChannel(topic_name + "a1_to_mp", + topic_name + "a1_to_mp", + "a1_to_mp_topic", + DEPTH, + sizeof(int64_t)*10000, + transfer_type, + dds_backend); + AbstractChannelPtr a1_to_a2 = GetChannelFactory() + .GetDDSChannel(topic_name + "a1_to_a2", + topic_name + "a1_to_a2", + "a1_to_a2_topic", + DEPTH, + sizeof(int64_t)*10000, + transfer_type, + dds_backend); + AbstractChannelPtr a2_to_a1 = GetChannelFactory() + .GetDDSChannel(topic_name + "a2_to_a1", + topic_name + "a2_to_a1", + "a2_to_a1_topic", + DEPTH, + sizeof(int64_t)*10000, + transfer_type, + dds_backend); + + auto target_fn_a1 = std::bind(&dds_target_fn_a1_bound, loop, + mp_to_a1, a1_to_mp, a1_to_a2, + a2_to_a1); + auto target_fn_a2 = std::bind(&dds_target_fn_a2_bound, loop, a1_to_a2, + a2_to_a1); + + ProcessType actor1 = mp.BuildActor(target_fn_a1); + ProcessType actor2 = mp.BuildActor(target_fn_a2); + + auto to_a1 = mp_to_a1->GetSendPort(); + to_a1->Start(); + auto from_a1 = a1_to_mp->GetRecvPort(); + from_a1->Start(); + MetaDataPtr metadata = std::make_shared(); + int64_t dims[] = {10000, 0, 0, 0, 0}; + int64_t nd = 1; + int64_t* array_ = reinterpret_cast + (malloc(sizeof(int64_t) * dims[0])); + memset(array_, 0, sizeof(int64_t) * dims[0]); + std::fill(array_, array_ + 10, 1); + GetMetadata(metadata, array_, nd, + static_cast(METADATA_TYPES::LONG), dims); + MetaDataPtr mptr; + LAVA_DUMP(LOG_UTTEST, "main process loop: %d\n", loop); + const clock_t start_time = std::clock(); + while (loop--) { + to_a1->Send(metadata); + free(metadata->mdata); + mptr = from_a1->Recv(); + metadata = mptr; + } + const clock_t end_time = std::clock(); + from_a1->Join(); + to_a1->Join(); + mp.Stop(); + mp.Cleanup(true); + + std::printf("dds cpp loop timedelta: %f\n", + ((end_time - start_time)/static_cast(CLOCKS_PER_SEC))); + LAVA_DUMP(LOG_UTTEST, "exit\n"); +} + +#if defined(FASTDDS_ENABLE) +TEST(TestDDSDelivery, FastDDSSHMLoop) { + GTEST_SKIP(); + dds_protocol("fast_shm_", + DDSTransportType::DDSSHM, + DDSBackendType::FASTDDSBackend); +} +TEST(TestDDSDelivery, FastDDSUDPv4Loop) { + dds_protocol("fast_UDPv4", + DDSTransportType::DDSUDPv4, + DDSBackendType::FASTDDSBackend); +} +#endif + +#if defined(CycloneDDS_ENABLE) +TEST(TestDDSDelivery, CycloneDDSUDPv4Loop) { + dds_protocol("cyclone_UDPv4", + DDSTransportType::DDSUDPv4, + DDSBackendType::CycloneDDSBackend); +} +#endif + +TEST(TestDDSSingleProcess, DDS1Process) { + GTEST_SKIP(); + LAVA_DUMP(LOG_UTTEST, "TestDDSSingleProcess starts.\n"); + AbstractChannelPtr dds_channel = GetChannelFactory() + .GetDDSChannel("test_DDSChannel_src", + "test_DDSChannel_dst", + "TestDDSSingleProcess_topic", + 5, + sizeof(int64_t), + DDSTransportType::DDSSHM, + DDSBackendType::FASTDDSBackend); + + auto send_port = dds_channel->GetSendPort(); + send_port->Start(); + auto recv_port = dds_channel->GetRecvPort(); + recv_port->Start(); + + MetaDataPtr metadata = std::make_shared(); + metadata->nd = 1; + metadata->type = 7; + metadata->elsize = 8; + metadata->total_size = 1; + metadata->dims[0] = 1; + metadata->strides[0] = 1; + metadata->mdata = + (reinterpret_cast + (malloc(sizeof(int64_t)+sizeof(MetaData)))+sizeof(MetaData)); + *reinterpret_cast(metadata->mdata) = 1; + + MetaDataPtr mptr; + int loop = loop_number; + int i = 0; + while (loop--) { + if (!(loop % 1000)) + LAVA_DUMP(LOG_UTTEST, "At iteration : %d * 1000\n", i++); + send_port->Send(metadata); + free(metadata->mdata); + mptr = recv_port->Recv(); + EXPECT_EQ(*reinterpret_cast(mptr->mdata), + *reinterpret_cast(metadata->mdata)); + (*reinterpret_cast(mptr->mdata))++; + metadata = mptr; + } + recv_port->Join(); + send_port->Join(); +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_fastdds_from_ros.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_fastdds_from_ros.cc new file mode 100644 index 000000000..69d5fbb3c --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_fastdds_from_ros.cc @@ -0,0 +1,32 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +using namespace message_infrastructure; // NOLINT + +#define LOOP_NUM 100 + +int main() { + auto dds_channel = GetChannelFactory() + .GetDDSChannel("test_channel_src", + "test_channel_dst", + "rt/dds_topic", + 10, + DEFAULT_NBYTES, + DDSTransportType::DDSUDPv4, + DDSBackendType::FASTDDSBackend); + auto dds_recv = dds_channel->GetRecvPort(); + int loop = LOOP_NUM; + + dds_recv->Start(); + while (loop--) { + MetaDataPtr res = dds_recv->Recv(); + printf("DDS recv : '%d'\n", *reinterpret_cast(res->mdata)); + } + dds_recv->Join(); + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_fastdds_to_ros.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_fastdds_to_ros.cc new file mode 100644 index 000000000..8559388dc --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_fastdds_to_ros.cc @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include + +using namespace message_infrastructure; // NOLINT + +#define LOOP_NUM 100 + +int main() { + auto dds_channel = GetChannelFactory() + .GetDDSChannel("test_fastdds_src", + "test_fastdds_dst", + "rt/dds_topic", + 10, + sizeof(int64_t), + DDSTransportType::DDSUDPv4, + DDSBackendType::FASTDDSBackend); + auto dds_send = dds_channel->GetSendPort(); + int loop = LOOP_NUM; + + dds_send->Start(); + MetaDataPtr metadata = std::make_shared(); + metadata->nd = 1; + metadata->type = 7; + metadata->elsize = 8; + metadata->total_size = 1; + metadata->dims[0] = 1; + metadata->strides[0] = 1; + metadata->mdata = reinterpret_cast(malloc(sizeof(int64_t))); + while (loop--) { + *reinterpret_cast(metadata->mdata) = loop; + dds_send->Send(metadata); + printf("DDS send : '%d'\n", loop); + sleep(0.1); + } + dds_send->Join(); + return 0; +} diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_grpc_delivery.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_grpc_delivery.cc new file mode 100644 index 000000000..dd1439662 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_grpc_delivery.cc @@ -0,0 +1,149 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +static void stop_fn() { + // exit(0); +} + +void grpc_target_fn1( + int loop, + AbstractChannelPtr mp_to_a1, + AbstractChannelPtr a1_to_mp, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto from_mp = mp_to_a1->GetRecvPort(); + auto to_mp = a1_to_mp->GetSendPort(); + auto to_a2 = a1_to_a2->GetSendPort(); + auto from_a2 = a2_to_a1->GetRecvPort(); + from_mp->Start(); + to_mp->Start(); + to_a2->Start(); + from_a2->Start(); + LAVA_DUMP(1, "grpc actor1, loop: %d\n", loop); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "grpc actor1 waitting\n"); + MetaDataPtr data = from_mp->Recv(); + LAVA_DUMP(LOG_UTTEST, "grpc actor1 recviced\n"); + (*reinterpret_cast(data->mdata))++; + to_a2->Send(MetaData2GrpcMetaData(data)); + free(reinterpret_cast(data->mdata)); + data = from_a2->Recv(); + (*reinterpret_cast(data->mdata))++; + to_mp->Send(MetaData2GrpcMetaData(data)); + free(reinterpret_cast(data->mdata)); + } + from_mp->Join(); + to_mp->Join(); + to_a2->Join(); + from_a2->Join(); + } + +void grpc_target_fn2( + int loop, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto to_a1 = a2_to_a1->GetSendPort(); + auto from_a1 = a1_to_a2->GetRecvPort(); + from_a1->Start(); + to_a1->Start(); + LAVA_DUMP(LOG_UTTEST, "grpc actor2, loop: %d\n", loop); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "grpc actor2 waitting\n"); + MetaDataPtr data = from_a1->Recv(); + LAVA_DUMP(LOG_UTTEST, "grpc actor2 recviced\n"); + (*reinterpret_cast(data->mdata))++; + to_a1->Send(MetaData2GrpcMetaData(data)); + free(reinterpret_cast(data->mdata)); + } + from_a1->Join(); + to_a1->Join(); + } + +TEST(TestGRPCChannel, GRPCLoop) { + MultiProcessing mp; + int loop = 1000; + AbstractChannelPtr mp_to_a1 = GetChannelFactory().GetDefRPCChannel( + "mp_to_a1", "mp_to_a1", 6); + AbstractChannelPtr a1_to_mp = GetChannelFactory().GetDefRPCChannel( + "a1_to_mp", "a1_to_mp", 6); + AbstractChannelPtr a1_to_a2 = GetChannelFactory().GetDefRPCChannel( + "a1_to_a2", "a1_to_a2", 6); + AbstractChannelPtr a2_to_a1 = GetChannelFactory().GetDefRPCChannel( + "a2_to_a1", "a2_to_a1", 6); + auto target_fn_a1 = std::bind(&grpc_target_fn1, + loop, + mp_to_a1, + a1_to_mp, + a1_to_a2, + a2_to_a1); + auto target_fn_a2 = std::bind(&grpc_target_fn2, + loop, + a1_to_a2, + a2_to_a1); + ProcessType actor1 = mp.BuildActor(target_fn_a1); + ProcessType actor2 = mp.BuildActor(target_fn_a2); + auto to_a1 = mp_to_a1->GetSendPort(); + to_a1->Start(); + auto from_a1 = a1_to_mp->GetRecvPort(); + from_a1->Start(); + MetaDataPtr metadata = std::make_shared(); + int64_t dims[] = {10000, 0, 0, 0, 0}; + int64_t nd = 1; + int64_t* array_ = reinterpret_cast + (malloc(sizeof(int64_t) * dims[0])); + memset(array_, 0, sizeof(int64_t) * dims[0]); + std::fill(array_, array_ + 10, 1); + GetMetadata(metadata, array_, nd, + static_cast(METADATA_TYPES::LONG), dims); + int expect_result = 1 + loop * 3; + const clock_t start_time = std::clock(); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "wait for response, remain loop: %d\n", loop); + to_a1->Send(MetaData2GrpcMetaData(metadata)); + free(reinterpret_cast(metadata->mdata)); + metadata = from_a1->Recv(); + LAVA_DUMP(LOG_UTTEST, "metadata:\n"); + LAVA_DUMP(LOG_UTTEST, "nd: %ld\n", metadata->nd); + LAVA_DUMP(LOG_UTTEST, "type: %ld\n", metadata->type); + LAVA_DUMP(LOG_UTTEST, "elsize: %ld\n", metadata->elsize); + LAVA_DUMP(LOG_UTTEST, "total_size: %ld\n", metadata->total_size); + LAVA_DUMP(LOG_UTTEST, "dims: {%ld, %ld, %ld, %ld, %ld}\n", + metadata->dims[0], metadata->dims[1], metadata->dims[2], + metadata->dims[3], metadata->dims[4]); + LAVA_DUMP(LOG_UTTEST, "strides: {%ld, %ld, %ld, %ld, %ld}\n", + metadata->strides[0], metadata->strides[1], metadata->strides[2], + metadata->strides[3], metadata->strides[4]); + LAVA_DUMP(LOG_UTTEST, "grpc mdata: %p, grpc *mdata: %ld\n", metadata->mdata, + *reinterpret_cast(metadata->mdata)); + } + const clock_t end_time = std::clock(); + to_a1->Join(); + from_a1->Join(); + int64_t result = *reinterpret_cast(metadata->mdata); + free(reinterpret_cast(metadata->mdata)); + LAVA_DUMP(LOG_UTTEST, "grpc result =%ld\n", result); + mp.Stop(); + mp.Cleanup(true); + if (result != expect_result) { + LAVA_DUMP(LOG_UTTEST, "expect_result: %d\n", expect_result); + LAVA_DUMP(LOG_UTTEST, "result: %ld\n", result); + LAVA_LOG_ERR("result != expect_result\n"); + throw; + } + std::printf("grpc cpp loop timedelta: %f\n", + ((end_time - start_time)/static_cast(CLOCKS_PER_SEC))); + LAVA_DUMP(LOG_UTTEST, "exit\n"); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_multiprocessing.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_multiprocessing.cc new file mode 100644 index 000000000..1f22982c0 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_multiprocessing.cc @@ -0,0 +1,86 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include + +namespace message_infrastructure { + +class Builder { + public: + void Build(int i); +}; + +void Builder::Build(int i) { + std::cout << "Builder running build " << i << std::endl; + std::cout << "Build " << i << "... Sleeping for 3s" << std::endl; + sleep(3); + std::cout << "Build " << i << "... Builder complete" << std::endl; +} + +void TargetFunction(Builder builder, int idx) { + std::cout << "Target Function running... ID " << idx << std::endl; + builder.Build(idx); +} + +TEST(TestMultiprocessing, MultiprocessingSpawn) { + // Spawns an actor + // Checks that actor is spawned successfully + GTEST_SKIP(); + Builder *builder = new Builder(); + MultiProcessing mp; + + AbstractActor::TargetFn target_fn; + + for (int i = 0; i < 1; i++) { + std::cout << "Loop " << i << std::endl; + auto bound_fn = std::bind(&TargetFunction, + (*builder), + i); + target_fn = bound_fn; + ProcessType return_value = mp.BuildActor(bound_fn); + std::cout << "Return Value --> " + << static_cast(return_value) + << std::endl; + } + + std::vector& actorList = mp.GetActors(); + std::cout << "Actor List Length --> " << actorList.size() << std::endl; + + // Stop any currently running actors + mp.Stop(); + mp.Cleanup(true); +} + +TEST(TestMultiprocessing, ActorStop) { + GTEST_SKIP(); + // Force stops all running actors + // Checks that actor status returns 1 (StatusStopped) + MultiProcessing mp; + Builder *builder = new Builder(); + AbstractActor::TargetFn target_fn; + + for (int i = 0; i < 5; i++) { + std::cout << "Loop " << i << std::endl; + auto bound_fn = std::bind(&TargetFunction, + (*builder), + i); + target_fn = bound_fn; + ProcessType return_value = mp.BuildActor(bound_fn); + std::cout << "Return Value --> " + << static_cast(return_value) + << std::endl; + } + + sleep(1); + + std::vector& actorList = mp.GetActors(); + std::cout << "Actor List Length --> " << actorList.size() << std::endl; + mp.Stop(); + mp.Cleanup(true); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_shm_delivery.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_shm_delivery.cc new file mode 100644 index 000000000..e6b8098a2 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_shm_delivery.cc @@ -0,0 +1,159 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { + +void stop_fn() { + // exit(0); +} + +void target_fn_a1_bound( + int loop, + AbstractChannelPtr mp_to_a1, + AbstractChannelPtr a1_to_mp, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto from_mp = mp_to_a1->GetRecvPort(); + from_mp->Start(); + auto to_mp = a1_to_mp->GetSendPort(); + to_mp->Start(); + auto to_a2 = a1_to_a2->GetSendPort(); + to_a2->Start(); + auto from_a2 = a2_to_a1->GetRecvPort(); + from_a2->Start(); + LAVA_DUMP(LOG_UTTEST, "shm actor1, loop: %d\n", loop); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "shm actor1 waitting\n"); + MetaDataPtr data = from_mp->Recv(); + LAVA_DUMP(LOG_UTTEST, "shm actor1 recviced\n"); + (*reinterpret_cast(data->mdata))++; + to_a2->Send(data); + free(data->mdata); + data = from_a2->Recv(); + (*reinterpret_cast(data->mdata))++; + to_mp->Send(data); + free(data->mdata); + } + from_mp->Join(); + from_a2->Join(); + } + +void target_fn_a2_bound( + int loop, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto from_a1 = a1_to_a2->GetRecvPort(); + from_a1->Start(); + auto to_a1 = a2_to_a1->GetSendPort(); + to_a1->Start(); + LAVA_DUMP(LOG_UTTEST, "shm actor2, loop: %d\n", loop); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "shm actor2 waitting\n"); + MetaDataPtr data = from_a1->Recv(); + LAVA_DUMP(LOG_UTTEST, "shm actor2 recviced\n"); + (*reinterpret_cast(data->mdata))++; + to_a1->Send(data); + free(data->mdata); + } + from_a1->Join(); + } + +TEST(TestShmDelivery, ShmLoop) { + MultiProcessing mp; + int loop = 1000; + const int queue_size = 1; + AbstractChannelPtr mp_to_a1 = GetChannelFactory().GetChannel( + ChannelType::SHMEMCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "mp_to_a1", + "mp_to_a1"); + AbstractChannelPtr a1_to_mp = GetChannelFactory().GetChannel( + ChannelType::SHMEMCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "a1_to_mp", + "a1_to_mp"); + AbstractChannelPtr a1_to_a2 = GetChannelFactory().GetChannel( + ChannelType::SHMEMCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "a1_to_a2", + "a1_to_a2"); + AbstractChannelPtr a2_to_a1 = GetChannelFactory().GetChannel( + ChannelType::SHMEMCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "a2_to_a1", + "a2_to_a1"); + auto target_fn_a1 = std::bind(&target_fn_a1_bound, loop, + mp_to_a1, a1_to_mp, a1_to_a2, + a2_to_a1); + auto target_fn_a2 = std::bind(&target_fn_a2_bound, loop, a1_to_a2, + a2_to_a1); + ProcessType actor1 = mp.BuildActor(target_fn_a1); + ProcessType actor2 = mp.BuildActor(target_fn_a2); + auto to_a1 = mp_to_a1->GetSendPort(); + to_a1->Start(); + auto from_a1 = a1_to_mp->GetRecvPort(); + from_a1->Start(); + MetaDataPtr metadata = std::make_shared(); + int64_t dims[] = {10000, 0, 0, 0, 0}; + int64_t nd = 1; + int64_t* array_ = reinterpret_cast + (malloc(sizeof(int64_t) * dims[0])); + memset(array_, 0, sizeof(int64_t) * dims[0]); + std::fill(array_, array_ + 10, 1); + GetMetadata(metadata, array_, nd, + static_cast(METADATA_TYPES::LONG), dims); + + LAVA_DUMP(LOG_UTTEST, "main process loop: %d\n", loop); + int expect_result = 1 + loop * 3; + const clock_t start_time = std::clock(); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "shm wait for response, remain loop: %d\n", loop); + to_a1->Send(metadata); + free(reinterpret_cast(metadata->mdata)); + metadata = from_a1->Recv(); + LAVA_DUMP(LOG_UTTEST, "metadata:\n"); + LAVA_DUMP(LOG_UTTEST, "nd: %ld\n", metadata->nd); + LAVA_DUMP(LOG_UTTEST, "type: %ld\n", metadata->type); + LAVA_DUMP(LOG_UTTEST, "elsize: %ld\n", metadata->elsize); + LAVA_DUMP(LOG_UTTEST, "total_size: %ld\n", metadata->total_size); + LAVA_DUMP(LOG_UTTEST, "dims: {%ld, %ld, %ld, %ld, %ld}\n", + metadata->dims[0], metadata->dims[1], metadata->dims[2], + metadata->dims[3], metadata->dims[4]); + LAVA_DUMP(LOG_UTTEST, "strides: {%ld, %ld, %ld, %ld, %ld}\n", + metadata->strides[0], metadata->strides[1], metadata->strides[2], + metadata->strides[3], metadata->strides[4]); + LAVA_DUMP(LOG_UTTEST, "mdata: %p, *mdata: %ld\n", metadata->mdata, + *reinterpret_cast(metadata->mdata)); + } + const clock_t end_time = std::clock(); + int64_t result = *reinterpret_cast(metadata->mdata); + free(reinterpret_cast(metadata->mdata)); + LAVA_DUMP(LOG_UTTEST, "shm result =%ld", result); + to_a1->Join(); + from_a1->Join(); + mp.Stop(); + mp.Cleanup(true); + if (result != expect_result) { + LAVA_DUMP(LOG_UTTEST, "expect_result: %d\n", expect_result); + LAVA_DUMP(LOG_UTTEST, "result: %ld\n", result); + LAVA_LOG_ERR("result != expect_result\n"); + throw; + } + std::printf("shm cpp loop timedelta: %f\n", + ((end_time - start_time)/static_cast(CLOCKS_PER_SEC))); + LAVA_DUMP(LOG_UTTEST, "exit\n"); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/test/test_socket_delivery.cc b/src/lava/magma/runtime/_c_message_infrastructure/test/test_socket_delivery.cc new file mode 100644 index 000000000..794ed6894 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/test/test_socket_delivery.cc @@ -0,0 +1,165 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace message_infrastructure { +static void stop_fn() { + // exit(0); +} + +void soket_target_fn1( + int loop, + AbstractChannelPtr mp_to_a1, + AbstractChannelPtr a1_to_mp, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto from_mp = mp_to_a1->GetRecvPort(); + auto to_mp = a1_to_mp->GetSendPort(); + auto to_a2 = a1_to_a2->GetSendPort(); + auto from_a2 = a2_to_a1->GetRecvPort(); + from_mp->Start(); + to_mp->Start(); + to_a2->Start(); + from_a2->Start(); + LAVA_DUMP(LOG_UTTEST, "socket actor1, loop: %d\n", loop); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "soket actor1 waitting\n"); + MetaDataPtr data = from_mp->Recv(); + LAVA_DUMP(LOG_UTTEST, "socket actor1 recviced\n"); + (*reinterpret_cast(data->mdata))++; + to_a2->Send(data); + free(reinterpret_cast(data->mdata)); + data = from_a2->Recv(); + (*reinterpret_cast(data->mdata))++; + to_mp->Send(data); + free(reinterpret_cast(data->mdata)); + } + from_mp->Join(); + to_mp->Join(); + to_a2->Join(); + from_a2->Join(); + } + +void soket_target_fn2( + int loop, + AbstractChannelPtr a1_to_a2, + AbstractChannelPtr a2_to_a1) { + auto to_a1 = a2_to_a1->GetSendPort(); + auto from_a1 = a1_to_a2->GetRecvPort(); + from_a1->Start(); + to_a1->Start(); + LAVA_DUMP(1, "socket actor2, loop: %d\n", loop); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "socket actor2 waitting\n"); + MetaDataPtr data = from_a1->Recv(); + LAVA_DUMP(LOG_UTTEST, "socket actor2 recviced\n"); + (*reinterpret_cast(data->mdata))++; + to_a1->Send(data); + free(reinterpret_cast(data->mdata)); + } + from_a1->Join(); + to_a1->Join(); + } + +TEST(TestSocketChannel, SocketLoop) { + MultiProcessing mp; + int loop = 1000; + const int queue_size = 1; + AbstractChannelPtr mp_to_a1 = GetChannelFactory().GetChannel( + ChannelType::SOCKETCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "mp_to_a1", + "mp_to_a1"); + AbstractChannelPtr a1_to_mp = GetChannelFactory().GetChannel( + ChannelType::SOCKETCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "a1_to_mp", + "a1_to_mp"); + AbstractChannelPtr a1_to_a2 = GetChannelFactory().GetChannel( + ChannelType::SOCKETCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "a1_to_a2", + "a1_to_a2"); + AbstractChannelPtr a2_to_a1 = GetChannelFactory().GetChannel( + ChannelType::SOCKETCHANNEL, + queue_size, + sizeof(int64_t)*10000, + "a2_to_a1", + "a2_to_a1"); + auto target_fn_a1 = std::bind(&soket_target_fn1, + loop, + mp_to_a1, + a1_to_mp, + a1_to_a2, + a2_to_a1); + auto target_fn_a2 = std::bind(&soket_target_fn2, + loop, + a1_to_a2, + a2_to_a1); + ProcessType actor1 = mp.BuildActor(target_fn_a1); + ProcessType actor2 = mp.BuildActor(target_fn_a2); + auto to_a1 = mp_to_a1->GetSendPort(); + to_a1->Start(); + auto from_a1 = a1_to_mp->GetRecvPort(); + from_a1->Start(); + MetaDataPtr metadata = std::make_shared(); + int64_t dims[] = {10000, 0, 0, 0, 0}; + int64_t nd = 1; + int64_t* array_ = reinterpret_cast + (malloc(sizeof(int64_t) * dims[0])); + memset(array_, 0, sizeof(int64_t) * dims[0]); + std::fill(array_, array_ + 10, 1); + GetMetadata(metadata, array_, nd, + static_cast(METADATA_TYPES::LONG), dims); + int expect_result = 1 + loop * 3; + const clock_t start_time = std::clock(); + while (loop--) { + LAVA_DUMP(LOG_UTTEST, "wait for response, remain loop: %d\n", loop); + to_a1->Send(metadata); + free(reinterpret_cast(metadata->mdata)); + metadata = from_a1->Recv(); + LAVA_DUMP(LOG_UTTEST, "metadata:\n"); + LAVA_DUMP(LOG_UTTEST, "nd: %ld\n", metadata->nd); + LAVA_DUMP(LOG_UTTEST, "type: %ld\n", metadata->type); + LAVA_DUMP(LOG_UTTEST, "elsize: %ld\n", metadata->elsize); + LAVA_DUMP(LOG_UTTEST, "total_size: %ld\n", metadata->total_size); + LAVA_DUMP(LOG_UTTEST, "dims: {%ld, %ld, %ld, %ld, %ld}\n", + metadata->dims[0], metadata->dims[1], metadata->dims[2], + metadata->dims[3], metadata->dims[4]); + LAVA_DUMP(LOG_UTTEST, "strides: {%ld, %ld, %ld, %ld, %ld}\n", + metadata->strides[0], metadata->strides[1], metadata->strides[2], + metadata->strides[3], metadata->strides[4]); + LAVA_DUMP(LOG_UTTEST, "mdata: %p, *mdata: %ld\n", metadata->mdata, + *reinterpret_cast(metadata->mdata)); + } + const clock_t end_time = std::clock(); + to_a1->Join(); + from_a1->Join(); + int64_t result = *reinterpret_cast(metadata->mdata); + free(reinterpret_cast(metadata->mdata)); + LAVA_DUMP(LOG_UTTEST, "socket result =%ld", result); + mp.Stop(); + mp.Cleanup(true); + if (result != expect_result) { + LAVA_DUMP(LOG_UTTEST, "expect_result: %d\n", expect_result); + LAVA_DUMP(LOG_UTTEST, "result: %ld\n", result); + LAVA_LOG_ERR("result != expect_result\n"); + } + std::printf("socket cpp loop timedelta: %f\n", + ((end_time - start_time)/static_cast(CLOCKS_PER_SEC))); + LAVA_DUMP(LOG_UTTEST, "exit\n"); +} + +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/_c_message_infrastructure/tools/log_merge.sh b/src/lava/magma/runtime/_c_message_infrastructure/tools/log_merge.sh new file mode 100755 index 000000000..6ff136ac5 --- /dev/null +++ b/src/lava/magma/runtime/_c_message_infrastructure/tools/log_merge.sh @@ -0,0 +1,11 @@ +#/bin/bash +LOG_FILE_FULL_NAME="lava_message_infrastructure.log" +LOG_FILE_SCATTERED_PATH=$1 +LOG_LEVEL=$2 +if [ "$LOG_LEVEL" = "" ];then + LOG_LEVEL="ERRO|INFO|WARN|DBUG|DUMP" +fi +if [ "$LOG_FILE_SCATTERED_PATH" = "" ];then + LOG_FILE_SCATTERED_PATH=$MSG_LOG_PATH +fi +grep -E "$LOG_LEVEL" "$LOG_FILE_SCATTERED_PATH"/lava_message_infrastructure_pid_*.log | sort -n > "$LOG_FILE_SCATTERED_PATH"/"$LOG_FILE_FULL_NAME" diff --git a/src/lava/magma/runtime/message_infrastructure/__init__.py b/src/lava/magma/runtime/message_infrastructure/__init__.py new file mode 100644 index 000000000..06535f73a --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/__init__.py @@ -0,0 +1,114 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import os +import platform +from glob import glob +import warnings + + +def _get_pure_py() -> bool: + pure_py_env = os.getenv("LAVA_PURE_PYTHON", 0) + system_name = platform.system().lower() + if system_name != "linux": + return True + return int(pure_py_env) > 0 + + +PURE_PYTHON_VERSION = _get_pure_py() + +if PURE_PYTHON_VERSION: + from abc import ABC, abstractmethod + import multiprocessing as mp + + if platform.system() != 'Windows': + mp.set_start_method('fork') + + class Channel(ABC): + @property + @abstractmethod + def src_port(self): + pass + + @property + @abstractmethod + def dst_port(self): + pass + + from .py_ports import AbstractTransferPort + from .pypychannel import ( + SendPort, + RecvPort, + create_channel) + from .pypychannel import CspSelector as Selector + SupportGRPCChannel = False + SupportFastDDSChannel = False + SupportCycloneDDSChannel = False + SupportTempChannel = False + + def getTempSendPort(addr_path: str): + return None + + def getTempRecvPort(): + return None, None + +else: + from ctypes import CDLL, RTLD_GLOBAL + + def load_library(): + lib_name = 'libmessage_infrastructure.so' + here = os.path.abspath(__file__) + lib_path = os.path.join(os.path.dirname(here), lib_name) + + if not os.path.exists(lib_path): + warnings.warn("No library file") + return + + extra_lib_folder = os.path.join(os.path.dirname(here), "install", "lib") + dds_libs = ["libfastcdr.so.*", + "libfastrtps.so.*", + "libddsc.so.*", + "libddscxx.so.*"] + if os.path.exists(extra_lib_folder): + for lib in dds_libs: + files = glob(os.path.join(extra_lib_folder, lib)) + for file in files: + CDLL(file, mode=RTLD_GLOBAL) + + CDLL(lib_path, mode=RTLD_GLOBAL) + + load_library() + # noinspection PyUnresolvedReferences + from lava.magma.runtime.message_infrastructure. \ + MessageInfrastructurePywrapper import ( # noqa # nosec + RecvPort, # noqa # nosec + AbstractTransferPort, # noqa # nosec + support_grpc_channel, + support_fastdds_channel, + support_cyclonedds_channel) + + ChannelQueueSize = 1 + SyncChannelBytes = 128 + SelectorSleepNs = 1 + + from .ports import ( # noqa # nosec + SendPort, # noqa # nosec + Channel, # noqa # nosec + Selector, # noqa # nosec + getTempSendPort, # noqa # nosec + getTempRecvPort, # noqa # nosec + create_channel) # noqa # nosec + SupportGRPCChannel = support_grpc_channel() + SupportFastDDSChannel = support_fastdds_channel() + SupportCycloneDDSChannel = support_cyclonedds_channel() + SupportTempChannel = True + + if SupportGRPCChannel: + from .ports import GetRPCChannel # noqa # nosec + if SupportFastDDSChannel or SupportCycloneDDSChannel: + from .ports import GetDDSChannel # noqa # nosec + from lava.magma.runtime.message_infrastructure. \ + MessageInfrastructurePywrapper import ( + DDSTransportType, # noqa # nosec + DDSBackendType) # noqa # nosec diff --git a/src/lava/magma/runtime/message_infrastructure/factory.py b/src/lava/magma/runtime/message_infrastructure/factory.py index 4aa392ce6..ea175899c 100644 --- a/src/lava/magma/runtime/message_infrastructure/factory.py +++ b/src/lava/magma/runtime/message_infrastructure/factory.py @@ -1,20 +1,28 @@ # Copyright (C) 2021-22 Intel Corporation # SPDX-License-Identifier: LGPL 2.1 or later # See: https://spdx.org/licenses/ - -from lava.magma.core.process.message_interface_enum import ActorType -from lava.magma.runtime.message_infrastructure.multiprocessing import \ - MultiProcessing +from lava.magma.runtime.message_infrastructure.message_interface_enum \ + import ActorType +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION class MessageInfrastructureFactory: - """Factory class to create the messaging infrastructure""" + """Creates the message infrastructure instance based on type""" @staticmethod def create(factory_type: ActorType): """Creates the message infrastructure instance based on type of actor framework being chosen.""" + if PURE_PYTHON_VERSION: + factory_type = ActorType.PyMultiProcessing + """type of actor framework being chosen""" # pylint: disable=W0105 if factory_type == ActorType.MultiProcessing: + from lava.magma.runtime.message_infrastructure.multiprocessing \ + import MultiProcessing + return MultiProcessing() + elif factory_type == ActorType.PyMultiProcessing: + from lava.magma.runtime.message_infrastructure.py_multiprocessing \ + import MultiProcessing return MultiProcessing() else: raise Exception("Unsupported factory_type") diff --git a/src/lava/magma/runtime/message_infrastructure/interfaces.py b/src/lava/magma/runtime/message_infrastructure/interfaces.py new file mode 100644 index 000000000..a1c24b463 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/interfaces.py @@ -0,0 +1,19 @@ +# Copyright (C) 2021-22 Intel Corporation +# SPDX-License-Identifier: LGPL 2.1 or later +# See: https://spdx.org/licenses/ + +from enum import IntEnum + + +class ChannelType(IntEnum): + """Type of a channel given the two process models""" + + PyPy = 0 + CPy = 1 + PyC = 2 + CNc = 3 + NcC = 4 + CC = 3 + NcNc = 5 + NcPy = 6 + PyNc = 7 diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py b/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py index 7123e316e..3da3d441c 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure_interface.py @@ -3,15 +3,9 @@ # See: https://spdx.org/licenses/ import typing as ty -if ty.TYPE_CHECKING: - from lava.magma.core.process.process import AbstractProcess - from lava.magma.compiler.builders.py_builder import PyProcessBuilder - from lava.magma.compiler.builders.runtimeservice_builder import \ - RuntimeServiceBuilder from abc import ABC, abstractmethod - -from lava.magma.compiler.channels.interfaces import ChannelType, Channel -from lava.magma.core.sync.domain import SyncDomain +from lava.magma.runtime.message_infrastructure import Channel +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType class MessageInfrastructureInterface(ABC): @@ -20,26 +14,42 @@ class MessageInfrastructureInterface(ABC): declare the underlying Channel Infrastructure Class to be used for message passing implementation.""" + @abstractmethod + def init(self): + """Init the messaging infrastructure""" + pass # pylint: disable=W0107 + @abstractmethod def start(self): """Starts the messaging infrastructure""" - @abstractmethod + def pre_stop(self): + """Stop MessageInfrastructure before join ports""" + pass # pylint: disable=W0107 + def stop(self): - """Stops the messaging infrastructure""" + """Stops the messaging infrastructure after join ports""" + pass # pylint: disable=W0107 @abstractmethod - def build_actor(self, target_fn: ty.Callable, builder: ty.Union[ - ty.Dict['AbstractProcess', 'PyProcessBuilder'], ty.Dict[ - SyncDomain, 'RuntimeServiceBuilder']]): + def build_actor(self, target_fn: ty.Callable, builder): """Given a target_fn starts a system process""" + def cleanup(self, block=False): + """Close all resources""" + pass # pylint: disable=W0107 + + def trace(self, logger) -> int: + """Trace actors' exceptions""" + return 0 + @property @abstractmethod def actors(self) -> ty.List[ty.Any]: """Returns a list of actors""" @abstractmethod - def channel_class(self, channel_type: ChannelType) -> ty.Type[Channel]: + def channel(self, channel_type: ChannelType, src_name, dst_name, + shape, dtype, size, sync=False) -> Channel: """Given the Channel Type, Return the Channel Implementation to be used during execution""" diff --git a/src/lava/magma/core/process/message_interface_enum.py b/src/lava/magma/runtime/message_infrastructure/message_interface_enum.py similarity index 88% rename from src/lava/magma/core/process/message_interface_enum.py rename to src/lava/magma/runtime/message_infrastructure/message_interface_enum.py index ae33c90fa..704ab56e3 100644 --- a/src/lava/magma/core/process/message_interface_enum.py +++ b/src/lava/magma/runtime/message_infrastructure/message_interface_enum.py @@ -7,3 +7,4 @@ class ActorType(IntEnum): MultiProcessing = 0 + PyMultiProcessing = 1 diff --git a/src/lava/magma/runtime/message_infrastructure/multiprocessing.py b/src/lava/magma/runtime/message_infrastructure/multiprocessing.py index 7c9e5aed9..a47996f89 100644 --- a/src/lava/magma/runtime/message_infrastructure/multiprocessing.py +++ b/src/lava/magma/runtime/message_infrastructure/multiprocessing.py @@ -3,41 +3,28 @@ # See: https://spdx.org/licenses/ import typing as ty -if ty.TYPE_CHECKING: - from lava.magma.core.process.process import AbstractProcess - from lava.magma.compiler.builders.py_builder import PyProcessBuilder - from lava.magma.compiler.builders.runtimeservice_builder import \ - RuntimeServiceBuilder +import numpy as np +from functools import partial + +from lava.magma.runtime.message_infrastructure.MessageInfrastructurePywrapper \ + import CppMultiProcessing +from lava.magma.runtime.message_infrastructure.MessageInfrastructurePywrapper \ + import ChannelType as ChannelBackend # noqa: E402 +from lava.magma.runtime.message_infrastructure \ + import Channel, ChannelQueueSize, SyncChannelBytes +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType +from lava.magma.runtime.message_infrastructure. \ + message_infrastructure_interface import MessageInfrastructureInterface import multiprocessing as mp -import os import traceback -from lava.magma.compiler.channels.interfaces import ChannelType, Channel -from lava.magma.compiler.channels.pypychannel import PyPyChannel -from lava.magma.runtime.message_infrastructure.shared_memory_manager import ( - SharedMemoryManager, -) - try: - from lava.magma.compiler.channels.cpychannel import \ - CPyChannel, PyCChannel - from lava.magma.compiler.channels.pyncchannel import PyNcChannel - from lava.magma.compiler.channels.ncpychannel import NcPyChannel + from lava.magma.core.model.c.type import LavaTypeTransfer except ImportError: - class CPyChannel: - pass - - class PyCChannel: - pass - - class PyNcChannel: - pass - - class NcPyChannel: + class LavaTypeTransfer: pass -from lava.magma.core.sync.domain import SyncDomain from lava.magma.runtime.message_infrastructure.message_infrastructure_interface\ import MessageInfrastructureInterface @@ -93,54 +80,57 @@ class MultiProcessing(MessageInfrastructureInterface): """Implements message passing using shared memory and multiprocessing""" def __init__(self): - self._smm: ty.Optional[SharedMemoryManager] = None - self._actors: ty.List[SystemProcess] = [] + self._mp: ty.Optional[CppMultiProcessing] = CppMultiProcessing() @property def actors(self): """Returns a list of actors""" - return self._actors + return self._mp.get_actors() - @property - def smm(self): - """Returns the underlying shared memory manager""" - return self._smm + def init(self): + pass def start(self): - """Starts the shared memory manager""" - self._smm = SharedMemoryManager() - self._smm.start() + """Init the MultiProcessing""" + for actor in self._mp.get_actors(): + actor.start() - def build_actor(self, target_fn: ty.Callable, builder: ty.Union[ - ty.Dict['AbstractProcess', 'PyProcessBuilder'], ty.Dict[ - SyncDomain, 'RuntimeServiceBuilder']]) -> ty.Any: + def build_actor(self, target_fn: ty.Callable, builder) -> ty.Any: """Given a target_fn starts a system (os) process""" - system_process = SystemProcess(target=target_fn, - args=(), - kwargs={"builder": builder}) - system_process.start() - self._actors.append(system_process) - return system_process - - def stop(self): + bound_target_fn = partial(target_fn, builder=builder) + self._mp.build_actor(bound_target_fn) + + def pre_stop(self): """Stops the shared memory manager""" - for actor in self._actors: - if actor._parent_pid == os.getpid(): - actor.join() - self._smm.shutdown() - - def channel_class(self, channel_type: ChannelType) -> ty.Type[Channel]: - """Given a channel type, returns the shared memory based class - implementation for the same""" + self._mp.stop() + + def pause(self): + for actor in self._mp.get_actors(): + actor.pause() + + def cleanup(self, block=False): + """Close all resources""" + self._mp.cleanup(block) + + def trace(self, logger) -> int: + """Trace actors' exceptions""" + # CppMessageInfrastructure cannot trace exceptions. + # It needs to stop all actors. + self.stop() + return 0 + + def channel(self, channel_type: ChannelType, src_name, dst_name, + shape, dtype, size, sync=False) -> Channel: if channel_type == ChannelType.PyPy: - return PyPyChannel - elif channel_type == ChannelType.PyC: - return PyCChannel - elif channel_type == ChannelType.CPy: - return CPyChannel - elif channel_type == ChannelType.PyNc: - return PyNcChannel - elif channel_type == ChannelType.NcPy: - return NcPyChannel + channel_bytes = np.prod(shape) * np.dtype(dtype).itemsize \ + if not sync else SyncChannelBytes + return Channel(ChannelBackend.SHMEMCHANNEL, ChannelQueueSize, + channel_bytes, src_name, dst_name, shape, dtype) + elif channel_type == ChannelType.PyC or channel_type == ChannelType.CPy: + temp_dtype = LavaTypeTransfer.cdtype2numpy(dtype) + channel_bytes = np.prod(shape) * np.dtype(temp_dtype).itemsize \ + if not sync else SyncChannelBytes + return Channel(ChannelBackend.SHMEMCHANNEL, size, + channel_bytes, src_name, dst_name, shape, dtype) else: raise Exception(f"Unsupported channel type {channel_type}") diff --git a/src/lava/magma/runtime/message_infrastructure/nx.py b/src/lava/magma/runtime/message_infrastructure/nx.py index 1ab70e319..8b529c80e 100644 --- a/src/lava/magma/runtime/message_infrastructure/nx.py +++ b/src/lava/magma/runtime/message_infrastructure/nx.py @@ -4,11 +4,22 @@ import typing as ty -from lava.magma.compiler.channels.interfaces import ChannelType +from lava.magma.runtime.message_infrastructure.interfaces import \ + ChannelType from lava.magma.core.sync.domain import SyncDomain from lava.magma.runtime.message_infrastructure \ .message_infrastructure_interface import \ MessageInfrastructureInterface +from lava.magma.runtime.message_infrastructure import Channel +try: + from lava.magma.runtime.message_infrastructure.nccchannel import NcCChannel + from lava.magma.runtime.message_infrastructure.cncchannel import CNcChannel +except ImportError: + class CNcChannel: + pass + + class NcCChannel: + pass class NxBoardMsgInterface(MessageInfrastructureInterface): @@ -29,9 +40,9 @@ def build_actor(self, target_fn: ty.Callable, builder: ty.Union[ def stop(self): """Stops the shared memory manager""" - def channel_class(self, channel_type: ChannelType) -> ty.Type[ChannelType]: + def channel_class(self, channel_type: ChannelType) -> ty.Type[Channel]: """Given a channel type, returns the shared memory based class - implementation for the same.""" + implementation for the same""" if channel_type == ChannelType.CNc: return CNcChannel elif channel_type == ChannelType.NcC: diff --git a/src/lava/magma/runtime/message_infrastructure/ports.py b/src/lava/magma/runtime/message_infrastructure/ports.py new file mode 100644 index 000000000..cb60c0a88 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/ports.py @@ -0,0 +1,124 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +from lava.magma.runtime.message_infrastructure \ + import ChannelQueueSize, SelectorSleepNs +from lava.magma.runtime.message_infrastructure.MessageInfrastructurePywrapper \ + import Channel as CppChannel +from lava.magma.runtime.message_infrastructure.MessageInfrastructurePywrapper \ + import ( + TempChannel, + support_grpc_channel, + support_fastdds_channel, + support_cyclonedds_channel, + AbstractTransferPort, + ChannelType, + CPPSelector, + RecvPort) + +import numpy as np +import typing as ty +import warnings + + +class Selector(CPPSelector): + def select(self, *args: ty.Tuple[RecvPort, ty.Callable[[], ty.Any]]): + return super().select(args, SelectorSleepNs) + + +class SendPort(AbstractTransferPort): + def __init__(self, send_port): + super().__init__() + self._cpp_send_port = send_port + + def send(self, data): + # TODO: Workaround for lava-loihi cpplib, need to change later + port_type = np.int32 if "LavaCDataType" in str(self.d_type) \ + else self.d_type + if data.dtype.type != np.str_ and \ + np.dtype(data.dtype).itemsize > np.dtype(port_type).itemsize: + warnings.warn("Sending data with miss matched dtype," + f"Transfer {data.dtype} to {port_type}") + data = data.astype(port_type) + # Use np.copy to handle slices input + self._cpp_send_port.send(np.copy(data)) + + def start(self): + self._cpp_send_port.start() + + def probe(self): + return self._cpp_send_port.probe() + + def join(self): + self._cpp_send_port.join() + + @property + def name(self): + return self._cpp_send_port.name + + @property + def shape(self): + return self._cpp_send_port.shape + + @property + def d_type(self): + return self._cpp_send_port.d_type + + @property + def size(self): + return self._cpp_send_port.size + + def get_channel_type(self): + return self._cpp_send_port.get_channel_type() + + +if support_grpc_channel(): + from lava.magma.runtime.message_infrastructure. \ + MessageInfrastructurePywrapper \ + import GetRPCChannel as CppRPCChannel + + class GetRPCChannel(CppRPCChannel): + + @property + def src_port(self): + return SendPort(super().src_port) + +if support_fastdds_channel() or support_cyclonedds_channel(): + from lava.magma.runtime.message_infrastructure. \ + MessageInfrastructurePywrapper \ + import GetDDSChannel as CppDDSChannel + + class GetDDSChannel(CppDDSChannel): + @property + def src_port(self): + return SendPort(super().src_port) + + +class Channel(CppChannel): + + @property + def src_port(self): + return SendPort(super().src_port) + + +def create_channel( + message_infrastructure: \ + "MessageInfrastructureInterface", # nosec # noqa + src_name, dst_name, shape, dtype, size): + channel_bytes = np.prod(shape) * np.dtype(dtype).itemsize + return Channel(ChannelType.SHMEMCHANNEL, ChannelQueueSize, channel_bytes, + src_name, dst_name, shape, dtype) + + +def getTempSendPort(addr_path: str): + tmp_channel = TempChannel(addr_path) + send_port = tmp_channel.src_port + return send_port + + +def getTempRecvPort(): + tmp_channel = TempChannel() + addr_path = tmp_channel.addr_path + recv_port = tmp_channel.dst_port + return addr_path, recv_port diff --git a/src/lava/magma/runtime/message_infrastructure/py_multiprocessing.py b/src/lava/magma/runtime/message_infrastructure/py_multiprocessing.py new file mode 100644 index 000000000..b1aab72f5 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/py_multiprocessing.py @@ -0,0 +1,153 @@ +# Copyright (C) 2021-22 Intel Corporation +# SPDX-License-Identifier: LGPL 2.1 or later +# See: https://spdx.org/licenses/ +import typing as ty +if ty.TYPE_CHECKING: + from lava.magma.core.process.process import AbstractProcess + from lava.magma.compiler.builders.py_builder import PyProcessBuilder + from lava.magma.compiler.builders.runtimeservice_builder import \ + RuntimeServiceBuilder + +import multiprocessing as mp +import os +import traceback + +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType +from lava.magma.runtime.message_infrastructure import Channel +from lava.magma.runtime.message_infrastructure.pypychannel import PyPyChannel +from lava.magma.runtime.message_infrastructure.shared_memory_manager import ( + SharedMemoryManager, +) + +try: + from lava.magma.runtime.message_infrastructure.cpychannel import \ + CPyChannel, PyCChannel +except ImportError: + class CPyChannel: + pass + + class PyCChannel: + pass + +from lava.magma.core.sync.domain import SyncDomain +from lava.magma.runtime.message_infrastructure.message_infrastructure_interface\ + import MessageInfrastructureInterface + +# pylint: disable=W0105 +"""Implements the Message Infrastructure Interface using Python +MultiProcessing Library. The MultiProcessing API is used to create actors +which will participate in exchanging messages. The Channel Infrastructure +further uses the SharedMemoryManager from MultiProcessing Library to +implement the communication backend in this implementation.""" + + +class SystemProcess(mp.Process): + """Wraps a process so that the exceptions can be collected if present""" + + def __init__(self, *args, **kwargs): + mp.Process.__init__(self, *args, **kwargs) + self._pconn, self._cconn = mp.Pipe() + self._exception = None + self._is_done = False + + def run(self): + try: + mp.Process.run(self) + self._cconn.send(None) + except Exception as e: + tb = traceback.format_exc() + self._cconn.send((e, tb)) + + def join(self, timeout=None): + if not self._is_done: + super().join() + super().close() + if self._pconn.poll(): + self._exception = self._pconn.recv() + self._cconn.close() + self._pconn.close() + self._is_done = True + + @property + def exception(self): + """Exception property.""" + if not self._is_done and self._pconn.poll(): + self._exception = self._pconn.recv() + return self._exception + + def close_pipe(self): + self._cconn.close() + self._pconn.close() + + +class MultiProcessing(MessageInfrastructureInterface): + """Implements message passing using shared memory and multiprocessing""" + + def __init__(self): + self._smm: ty.Optional[SharedMemoryManager] = None + self._actors: ty.List[SystemProcess] = [] + + @property + def actors(self): + """Returns a list of actors""" + return self._actors + + @property + def smm(self): + """Returns the underlying shared memory manager""" + return self._smm + + def init(self): + """Starts the shared memory manager""" + self._smm = SharedMemoryManager() + self._smm.start() + + def start(self): + pass + + def build_actor(self, target_fn: ty.Callable, builder: ty.Union[ + ty.Dict['AbstractProcess', 'PyProcessBuilder'], ty.Dict[ + SyncDomain, 'RuntimeServiceBuilder']]) -> ty.Any: + """Given a target_fn starts a system (os) process""" + system_process = SystemProcess(target=target_fn, + args=(), + kwargs={"builder": builder}) + system_process.start() + self._actors.append(system_process) + return system_process + + def stop(self): + """Stops the shared memory manager""" + for actor in self._actors: + if actor._parent_pid == os.getpid(): + actor.join() + actor.close_pipe() + self._smm.shutdown() + + def trace(self, logger) -> int: + error_cnt = 0 + for actors in self._actors: + actors.join() + if actors.exception: + _, traceback = actors.exception + logger.info(traceback) + error_cnt += 1 + actors.close_pipe() + return error_cnt + + def channel_class(self, channel_type: ChannelType) -> ty.Type[Channel]: + """Given a channel type, returns the shared memory based class + implementation for the same""" + if channel_type == ChannelType.PyPy: + return PyPyChannel + elif channel_type == ChannelType.PyC: + return PyCChannel + elif channel_type == ChannelType.CPy: + return CPyChannel + else: + raise Exception(f"Unsupported channel type {channel_type}") + + def channel(self, channel_type: ChannelType, src_name, dst_name, + shape, dtype, size, sync=False) -> Channel: + channel_class = self.channel_class(channel_type) + return channel_class(self, src_name, dst_name, shape, dtype, size) diff --git a/src/lava/magma/compiler/channels/interfaces.py b/src/lava/magma/runtime/message_infrastructure/py_ports.py similarity index 58% rename from src/lava/magma/compiler/channels/interfaces.py rename to src/lava/magma/runtime/message_infrastructure/py_ports.py index e7c533dfe..0227de8e3 100644 --- a/src/lava/magma/compiler/channels/interfaces.py +++ b/src/lava/magma/runtime/message_infrastructure/py_ports.py @@ -1,15 +1,14 @@ -# Copyright (C) 2021-22 Intel Corporation +# Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: LGPL 2.1 or later # See: https://spdx.org/licenses/ import typing as ty from abc import ABC, abstractmethod -from enum import IntEnum import numpy as np -class AbstractCspPort(ABC): +class AbstractTransferPort(ABC): """Abstract base class for CSP channel.""" @property @@ -44,39 +43,13 @@ def is_msg_size_static(self) -> bool: return True -class AbstractCspSendPort(AbstractCspPort): +class AbstractSendPort(AbstractTransferPort): @abstractmethod def send(self, data: np.ndarray): pass -class AbstractCspRecvPort(AbstractCspPort): +class AbstractRecvPort(AbstractTransferPort): @abstractmethod def recv(self) -> np.ndarray: pass - - -class Channel(ABC): - @property - @abstractmethod - def src_port(self) -> AbstractCspSendPort: - pass - - @property - @abstractmethod - def dst_port(self) -> AbstractCspRecvPort: - pass - - -class ChannelType(IntEnum): - """Type of a channel given the two process models""" - - PyPy = 0 - CPy = 1 - PyC = 2 - CNc = 3 - NcC = 4 - CC = 3 - NcNc = 5 - NcPy = 6 - PyNc = 7 diff --git a/src/lava/magma/compiler/channels/pypychannel.py b/src/lava/magma/runtime/message_infrastructure/pypychannel.py similarity index 88% rename from src/lava/magma/compiler/channels/pypychannel.py rename to src/lava/magma/runtime/message_infrastructure/pypychannel.py index cc4a51e4c..db80be3e4 100644 --- a/src/lava/magma/compiler/channels/pypychannel.py +++ b/src/lava/magma/runtime/message_infrastructure/pypychannel.py @@ -10,20 +10,23 @@ from time import time from scipy.sparse import csr_matrix from lava.utils.sparse import find -from lava.magma.compiler.channels.watchdog import Watchdog, NoOPWatchdog +from lava.magma.runtime.message_infrastructure.watchdog import \ + Watchdog, NoOPWatchdog import numpy as np -from lava.magma.compiler.channels.interfaces import ( - Channel, - AbstractCspSendPort, - AbstractCspRecvPort, +from lava.magma.runtime.message_infrastructure import Channel +from lava.magma.runtime.message_infrastructure.py_ports import ( + AbstractSendPort, + AbstractRecvPort, + AbstractTransferPort, ) if ty.TYPE_CHECKING: - from lava.magma.runtime.message_infrastructure \ - .message_infrastructure_interface import ( - MessageInfrastructureInterface) + from lava.magma.runtime.message_infrastructure. \ + message_infrastructure_interface \ + import ( + MessageInfrastructureInterface) # silence pyflakes @dataclass @@ -33,7 +36,7 @@ class Proto: nbytes: int -class CspSendPort(AbstractCspSendPort): +class SendPort(AbstractSendPort): """ CspSendPort is a low level send port implementation based on CSP semantics. It can be understood as the input port of a CSP channel. @@ -93,9 +96,7 @@ def start(self): shape=self._shape, dtype=self._dtype, buffer=self._shm.buf[ - self._nbytes * i: self._nbytes * (i + 1) - ], - ) + self._nbytes * i: self._nbytes * (i + 1)],) for i in range(self._size) ] self._semaphore = BoundedSemaphore(self._size) @@ -113,7 +114,7 @@ def _ack_callback(self, ack): not_full = self.probe() self._semaphore.release() if self.observer and not not_full: - self.observer() + self.observer() # pylint: disable=E1102 except EOFError: pass @@ -187,7 +188,7 @@ def get(self, block=True, timeout=None, peek=False): return item -class CspRecvPort(AbstractCspRecvPort): +class RecvPort(AbstractRecvPort): """ CspRecvPort is a low level recv port implementation based on CSP semantics. It can be understood as the output port of a CSP channel. @@ -250,8 +251,7 @@ def start(self): shape=self._shape, dtype=self._dtype, buffer=self._shm.buf[ - self._nbytes * i: self._nbytes * (i + 1) - ], + self._nbytes * i: self._nbytes * (i + 1)], ) for i in range(self._size) ] @@ -270,7 +270,7 @@ def _req_callback(self, req): not_empty = self.probe() self._queue.put_nowait(0) if self.observer and not not_empty: - self.observer() + self.observer() # pylint: disable=E1102 except EOFError: pass @@ -334,8 +334,7 @@ def _set_observer( def select( self, *channel_actions: ty.Tuple[ - ty.Union[CspSendPort, CspRecvPort], - ty.Callable[[], ty.Any] + ty.Union[SendPort, RecvPort], ty.Callable[[], ty.Any] ], ) -> None: """ @@ -386,20 +385,22 @@ def __init__( req = Semaphore(0) ack = Semaphore(0) proto = Proto(shape=shape, dtype=dtype, nbytes=nbytes) - self._src_port = CspSendPort(src_name, shm, proto, size, req, ack, - src_send_watchdog, - src_join_watchdog) - self._dst_port = CspRecvPort(dst_name, shm, proto, size, req, ack, - dst_recv_watchdog, - dst_join_watchdog) + self._src_port = SendPort(src_name, shm, proto, size, req, ack) + self._dst_port = RecvPort(dst_name, shm, proto, size, req, ack) def nbytes(self, shape, dtype): return np.prod(shape) * np.dtype(dtype).itemsize @property - def src_port(self) -> AbstractCspSendPort: + def src_port(self) -> AbstractTransferPort: return self._src_port @property - def dst_port(self) -> AbstractCspRecvPort: + def dst_port(self) -> AbstractTransferPort: return self._dst_port + + +def create_channel(message_infrastructure: "MessageInfrastructureInterface", + src_name, dst_name, shape, dtype, size): + return PyPyChannel(message_infrastructure, src_name, dst_name, + shape, dtype, size) diff --git a/src/lava/magma/compiler/channels/watchdog.py b/src/lava/magma/runtime/message_infrastructure/watchdog.py similarity index 100% rename from src/lava/magma/compiler/channels/watchdog.py rename to src/lava/magma/runtime/message_infrastructure/watchdog.py diff --git a/src/lava/magma/runtime/runtime.py b/src/lava/magma/runtime/runtime.py index d5dde5a92..a1db67de8 100644 --- a/src/lava/magma/runtime/runtime.py +++ b/src/lava/magma/runtime/runtime.py @@ -8,15 +8,24 @@ import sys import traceback import typing as ty - import numpy as np +from lava.magma.runtime.message_infrastructure import (RecvPort, + SendPort, + Channel, + SupportTempChannel, + Selector, + getTempSendPort, + getTempRecvPort, + AbstractTransferPort) + from scipy.sparse import csr_matrix from lava.magma.compiler.var_model import AbstractVarModel, LoihiSynapseVarModel -from lava.magma.core.process.message_interface_enum import ActorType +from lava.magma.runtime.message_infrastructure.message_interface_enum import \ + ActorType from lava.magma.runtime.message_infrastructure.factory import \ MessageInfrastructureFactory -from lava.magma.runtime.message_infrastructure. \ - message_infrastructure_interface import \ +from lava.magma.runtime. \ + message_infrastructure.message_infrastructure_interface import \ MessageInfrastructureInterface from lava.magma.runtime.mgmt_token_enums import (MGMT_COMMAND, MGMT_RESPONSE, enum_equal, enum_to_np) @@ -25,8 +34,6 @@ if ty.TYPE_CHECKING: from lava.magma.core.process.process import AbstractProcess -from lava.magma.compiler.channels.pypychannel import CspRecvPort, CspSendPort, \ - CspSelector from lava.magma.compiler.builders.channel_builder import ( ChannelBuilderMp, RuntimeChannelBuilderMp, ServiceChannelBuilderMp, ChannelBuilderPyNc) @@ -34,14 +41,15 @@ from lava.magma.compiler.builders.py_builder import PyProcessBuilder from lava.magma.compiler.builders.runtimeservice_builder import \ RuntimeServiceBuilder -from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel, \ +from lava.magma.runtime.message_infrastructure.interfaces import \ ChannelType from lava.magma.compiler.executable import Executable from lava.magma.compiler.node import NodeConfig from lava.magma.core.process.ports.ports import create_port_id from lava.magma.core.run_conditions import (AbstractRunCondition, RunContinuous, RunSteps) -from lava.magma.compiler.channels.watchdog import WatchdogManagerInterface +from lava.magma.runtime.message_infrastructure.watchdog import \ + WatchdogManagerInterface """Defines a Runtime which takes a lava executable and a pluggable message passing infrastructure (for instance multiprocessing+shared memory or ray in @@ -128,9 +136,9 @@ def __init__(self, self._is_started: bool = False self._req_paused: bool = False self._req_stop: bool = False - self.runtime_to_service: ty.Iterable[CspSendPort] = [] - self.service_to_runtime: ty.Iterable[CspRecvPort] = [] - self._open_ports: ty.List[AbstractCspPort] = [] + self.runtime_to_service: ty.Iterable[SendPort] = [] + self.service_to_runtime: ty.Iterable[RecvPort] = [] + self._open_ports: ty.List[AbstractTransferPort] = [] self.num_steps: int = 0 self._watchdog_manager = None @@ -186,7 +194,7 @@ def _build_message_infrastructure(self): _messaging_infrastructure_type and Start it""" self._messaging_infrastructure = MessageInfrastructureFactory.create( self._messaging_infrastructure_type) - self._messaging_infrastructure.start() + self._messaging_infrastructure.init() def _get_process_builder_for_process(self, process: AbstractProcess) -> \ AbstractProcessBuilder: @@ -310,15 +318,15 @@ def _get_resp_for_run(self): Gets response from RuntimeServices """ if self._is_running: - selector = CspSelector() + selector = Selector() # Poll on all responses channel_actions = [(recv_port, (lambda y: (lambda: y))( - recv_port)) for - recv_port in - self.service_to_runtime] + recv_port)) for recv_port in self.service_to_runtime] rsps = [] while True: recv_port = selector.select(*channel_actions) + if recv_port is None: + continue data = recv_port.recv() rsps.append(data) if enum_equal(data, MGMT_RESPONSE.REQ_PAUSE): @@ -330,14 +338,8 @@ def _get_resp_for_run(self): elif not enum_equal(data, MGMT_RESPONSE.DONE): if enum_equal(data, MGMT_RESPONSE.ERROR): # Receive all errors from the ProcessModels - error_cnt = 0 - for actors in \ - self._messaging_infrastructure.actors: - actors.join() - if actors.exception: - _, traceback = actors.exception - self.log.info(traceback) - error_cnt += 1 + error_cnt = self._messaging_infrastructure.trace( + self.log) raise RuntimeError( f"{error_cnt} Exception(s) occurred. See " f"output above for details.") @@ -370,6 +372,7 @@ def _run(self, run_condition: AbstractRunCondition): """ if self._is_started: self._is_running = True + self._messaging_infrastructure.start() if isinstance(run_condition, RunSteps): self.num_steps = run_condition.num_steps for send_port in self.runtime_to_service: @@ -435,9 +438,12 @@ def stop(self): data = recv_port.recv() if not enum_equal(data, MGMT_RESPONSE.TERMINATED): raise RuntimeError(f"Runtime Received {data}") + + self._messaging_infrastructure.pre_stop() self.join() self._is_running = False self._is_started = False + self._messaging_infrastructure.cleanup(True) # Send messages to RuntimeServices to stop as soon as possible. else: self.log.info("Runtime not started yet.") @@ -481,28 +487,36 @@ def set_var(self, var_id: int, value: np.ndarray, idx: np.ndarray = None): # from a model with model_id and var with var_id # 1. Send SET Command - req_port: CspSendPort = self.runtime_to_service[runtime_srv_id] + req_port: SendPort = self.runtime_to_service[runtime_srv_id] req_port.send(MGMT_COMMAND.SET_DATA) req_port.send(enum_to_np(model_id)) req_port.send(enum_to_np(var_id)) - rsp_port: CspRecvPort = self.service_to_runtime[runtime_srv_id] + rsp_port: RecvPort = self.service_to_runtime[runtime_srv_id] # 2. Reshape the data buffer: np.ndarray = value if idx: buffer = buffer[idx] - buffer_shape: ty.Tuple[int, ...] = buffer.shape - num_items: int = np.prod(buffer_shape).item() - reshape_order = 'F' if isinstance( - ev, LoihiSynapseVarModel) else 'C' - buffer = buffer.reshape((1, num_items), order=reshape_order) - - # 3. Send [NUM_ITEMS, DATA1, DATA2, ...] - data_port: CspSendPort = self.runtime_to_service[runtime_srv_id] - data_port.send(enum_to_np(num_items)) - for i in range(num_items): - data_port.send(enum_to_np(buffer[0, i], np.float64)) + + if SupportTempChannel: + addr_path = rsp_port.recv() + send_port = getTempSendPort(str(addr_path[0])) + send_port.start() + send_port.send(buffer) + send_port.join() + else: + # 3. Send [NUM_ITEMS, DATA1, DATA2, ...] + buffer_shape: ty.Tuple[int, ...] = buffer.shape + num_items: int = np.prod(buffer_shape).item() + reshape_order = 'F' if isinstance(ev, LoihiSynapseVarModel) \ + else 'C' + buffer = buffer.reshape((1, num_items), order=reshape_order) + data_port: SendPort = self.runtime_to_service[runtime_srv_id] + data_port.send(enum_to_np(num_items)) + for i in range(num_items): + data_port.send(enum_to_np(buffer[0, i], np.float64)) + rsp = rsp_port.recv() if not enum_equal(rsp, MGMT_RESPONSE.SET_COMPLETE): raise RuntimeError("Var Set couldn't get successfully " @@ -534,31 +548,40 @@ def get_var(self, var_id: int, idx: np.ndarray = None) -> np.ndarray: # from a model with model_id and var with var_id # 1. Send GET Command - req_port: CspSendPort = self.runtime_to_service[runtime_srv_id] + req_port: SendPort = self.runtime_to_service[runtime_srv_id] req_port.send(MGMT_COMMAND.GET_DATA) req_port.send(enum_to_np(model_id)) req_port.send(enum_to_np(var_id)) - # 2. Receive Data [NUM_ITEMS, DATA1, DATA2, ...] - data_port: CspRecvPort = self.service_to_runtime[runtime_srv_id] - num_items: int = int(data_port.recv()[0].item()) - - if ev.dtype == csr_matrix: - buffer = np.zeros(num_items) - + if SupportTempChannel: + addr_path, recv_port = getTempRecvPort() + recv_port.start() + req_port.send(np.array([addr_path])) + buffer = recv_port.recv() + recv_port.join() + if ev.dtype == csr_matrix: + return buffer[idx] if idx else buffer + if buffer.dtype.type != np.str_: + reshape_order = 'F' \ + if isinstance(ev, LoihiSynapseVarModel) else 'C' + buffer = buffer.ravel(order=reshape_order).reshape(ev.shape) + else: + # 2. Receive Data [NUM_ITEMS, DATA1, DATA2, ...] + data_port: RecvPort = self.service_to_runtime[runtime_srv_id] + num_items: int = int(data_port.recv()[0].item()) + if ev.dtype == csr_matrix: + buffer = np.zeros(num_items) + for i in range(num_items): + buffer[i] = data_port.recv()[0] + return buffer[idx] if idx else buffer + buffer: np.ndarray = np.zeros((1, np.prod(ev.shape))) for i in range(num_items): - buffer[i] = data_port.recv()[0] - - return buffer[idx] if idx else buffer - - buffer: np.ndarray = np.zeros((1, np.prod(ev.shape))) - for i in range(num_items): - buffer[0, i] = data_port.recv()[0] + buffer[0, i] = data_port.recv()[0] + # 3. Reshape result and return + reshape_order = 'F' if isinstance(ev, LoihiSynapseVarModel) \ + else 'C' + buffer = buffer.reshape(ev.shape, order=reshape_order) - # 3. Reshape result and return - reshape_order = 'F' if isinstance( - ev, LoihiSynapseVarModel) else 'C' - buffer = buffer.reshape(ev.shape, order=reshape_order) if idx: return buffer[idx] else: diff --git a/src/lava/magma/runtime/runtime_services/channel_broker/channel_broker.py b/src/lava/magma/runtime/runtime_services/channel_broker/channel_broker.py index 12da2257c..0a439a5b7 100644 --- a/src/lava/magma/runtime/runtime_services/channel_broker/channel_broker.py +++ b/src/lava/magma/runtime/runtime_services/channel_broker/channel_broker.py @@ -9,8 +9,12 @@ import numpy as np import typing as ty -from lava.magma.compiler.channels.interfaces import AbstractCspPort -from lava.magma.compiler.channels.pypychannel import CspSelector, PyPyChannel +from lava.magma.runtime.message_infrastructure import ( + AbstractTransferPort, + create_channel, + Selector +) +from lava.magma.runtime.message_infrastructure import Channel as MsgChannel from lava.magma.runtime.message_infrastructure.shared_memory_manager import ( SharedMemoryManager, ) @@ -58,7 +62,7 @@ def __init__(self, def generate_channel_name(prefix: str, port_idx: int, - csp_port: AbstractCspPort, + csp_port: AbstractTransferPort, c_builder_idx: int) -> str: return f"{prefix}{str(port_idx)}_{str(csp_port.name)}_{str(c_builder_idx)}" @@ -101,7 +105,7 @@ def __init__(self, self.c_outports_to_poll: ty.Dict[Channel, COutPort] = {} self.smm: SharedMemoryManager = SharedMemoryManager() - self.mgmt_channel: ty.Optional[PyPyChannel] = None + self.mgmt_channel: ty.Optional[MsgChannel] = None self.grpc_stopping_event: ty.Optional[threading.Event] = None self.port_poller: ty.Optional[threading.Thread] = None self.grpc_poller: ty.Optional[threading.Thread] = None @@ -110,13 +114,15 @@ def run(self): """Start the polling threads""" if not self.has_started: self.smm.start() - self.mgmt_channel = PyPyChannel( + + self.mgmt_channel = create_channel( message_infrastructure=self, src_name="mgmt_channel", dst_name="mgmt_channel", shape=(1,), dtype=np.int32, - size=1) + size=1, + ) self.mgmt_channel.src_port.start() self.mgmt_channel.dst_port.start() self.port_poller = threading.Thread(target=self.poll_c_inports) @@ -143,7 +149,7 @@ def poll_c_inports(self): After sending requests to the GRPC channel the process is informed about completion. """ - selector = CspSelector() + selector = Selector() while True: # Need to poll both GRPC and CSP ports for messages @@ -156,10 +162,17 @@ def poll_c_inports(self): channel_actions.append((self.mgmt_channel.dst_port, lambda: ('stop', None))) - action, channel = selector.select(*channel_actions) + + resp = selector.select(*channel_actions) + if resp is None: + continue + + action, channel = resp + if action == "stop": + self.mgmt_channel.dst_port.recv() return - else: + elif action is not None: action._recv(channel) def poll_c_outports(self): diff --git a/src/lava/magma/runtime/runtime_services/interfaces.py b/src/lava/magma/runtime/runtime_services/interfaces.py index fd35c621c..0eae48bf4 100644 --- a/src/lava/magma/runtime/runtime_services/interfaces.py +++ b/src/lava/magma/runtime/runtime_services/interfaces.py @@ -5,9 +5,9 @@ import typing as ty from abc import ABC, abstractmethod -from lava.magma.compiler.channels.pypychannel import ( - CspRecvPort, - CspSendPort +from lava.magma.runtime.message_infrastructure import ( + RecvPort, + SendPort, ) from lava.magma.core.sync.protocol import AbstractSyncProtocol @@ -18,8 +18,8 @@ def __init__(self, protocol): self.runtime_service_id: ty.Optional[int] = None - self.runtime_to_service: ty.Optional[CspRecvPort] = None - self.service_to_runtime: ty.Optional[CspSendPort] = None + self.runtime_to_service: ty.Optional[RecvPort] = None + self.service_to_runtime: ty.Optional[SendPort] = None self.model_ids: ty.List[int] = [] diff --git a/src/lava/magma/runtime/runtime_services/runtime_service.py b/src/lava/magma/runtime/runtime_services/runtime_service.py index d62cffdbd..9db26abaa 100644 --- a/src/lava/magma/runtime/runtime_services/runtime_service.py +++ b/src/lava/magma/runtime/runtime_services/runtime_service.py @@ -30,11 +30,13 @@ import numpy as np -from lava.magma.compiler.channels.pypychannel import ( - CspSelector, - CspRecvPort, - CspSendPort +from lava.magma.runtime.message_infrastructure import ( + RecvPort, + SendPort, + SupportTempChannel, ) +from lava.magma.runtime.message_infrastructure import Selector + from lava.magma.core.sync.protocol import AbstractSyncProtocol from lava.magma.runtime.mgmt_token_enums import ( enum_to_np, @@ -60,8 +62,8 @@ def __init__( self.log = logging.getLogger(__name__) self.log.setLevel(kwargs.get("loglevel", logging.WARNING)) super(PyRuntimeService, self).__init__(protocol=protocol) - self.service_to_process: ty.Iterable[CspSendPort] = [] - self.process_to_service: ty.Iterable[CspRecvPort] = [] + self.service_to_process: ty.List[SendPort] = [] + self.process_to_service: ty.List[RecvPort] = [] def start(self): """Start the necessary channels to coordinate with runtime and group @@ -93,27 +95,40 @@ def _relay_to_runtime_data_given_model_id(self, model_id: int): """Relays data received from ProcessModel given by model id to the runtime""" process_idx = self.model_ids.index(model_id) - data_recv_port = self.process_to_service[process_idx] - data_relay_port = self.service_to_runtime - num_items = data_recv_port.recv() - data_relay_port.send(num_items) - for _ in range(int(num_items[0])): - value = data_recv_port.recv() - data_relay_port.send(value) + if SupportTempChannel: + addr_recv_port = self.runtime_to_service + addr_relay_port = self.service_to_process[process_idx] + addr_path = addr_recv_port.recv() + addr_relay_port.send(addr_path) + else: + data_recv_port = self.process_to_service[process_idx] + data_relay_port = self.service_to_runtime + num_items = data_recv_port.recv() + data_relay_port.send(num_items) + for _ in range(int(num_items[0])): + value = data_recv_port.recv() + data_relay_port.send(value) def _relay_to_pm_data_given_model_id(self, model_id: int) -> MGMT_RESPONSE: """Relays data received from the runtime to the ProcessModel given by the model id.""" process_idx = self.model_ids.index(model_id) - data_recv_port = self.runtime_to_service - data_relay_port = self.service_to_process[process_idx] - resp_port = self.process_to_service[process_idx] - # Receive and relay number of items - num_items = data_recv_port.recv() - data_relay_port.send(num_items) - # Receive and relay data1, data2, ... - for _ in range(int(num_items[0].item())): - data_relay_port.send(data_recv_port.recv()) + if SupportTempChannel: + addr_recv_port = self.process_to_service[process_idx] + addr_relay_port = self.service_to_runtime + addr_path = addr_recv_port.recv() + addr_relay_port.send(addr_path) + resp_port = self.process_to_service[process_idx] + else: + data_recv_port = self.runtime_to_service + data_relay_port = self.service_to_process[process_idx] + resp_port = self.process_to_service[process_idx] + # Receive and relay number of items + num_items = data_recv_port.recv() + data_relay_port.send(num_items) + # Receive and relay data1, data2, ... + for _ in range(int(num_items[0].item())): + data_relay_port.send(data_recv_port.recv()) rsp = resp_port.recv() return rsp @@ -254,6 +269,47 @@ def _get_pm_resp(self) -> ty.Iterable[MGMT_RESPONSE]: self.req_stop = True return rcv_msgs + def _relay_to_runtime_data_given_model_id(self, model_id: int): + """Relays data received from ProcessModel given by model id to the + runtime""" + process_idx = self.model_ids.index(model_id) + if SupportTempChannel: + addr_recv_port = self.runtime_to_service + addr_relay_port = self.service_to_process[process_idx] + addr_path = addr_recv_port.recv() + addr_relay_port.send(addr_path) + else: + data_recv_port = self.process_to_service[process_idx] + data_relay_port = self.service_to_runtime + num_items = data_recv_port.recv() + data_relay_port.send(num_items) + for _ in range(int(num_items[0])): + value = data_recv_port.recv() + data_relay_port.send(value) + + def _relay_to_pm_data_given_model_id(self, model_id: int) -> MGMT_RESPONSE: + """Relays data received from the runtime to the ProcessModel given by + the model id.""" + process_idx = self.model_ids.index(model_id) + if SupportTempChannel: + addr_recv_port = self.process_to_service[process_idx] + addr_relay_port = self.service_to_runtime + addr_path = addr_recv_port.recv() + addr_relay_port.send(addr_path) + resp_port = self.process_to_service[process_idx] + else: + data_recv_port = self.runtime_to_service + data_relay_port = self.service_to_process[process_idx] + resp_port = self.process_to_service[process_idx] + # Receive and relay number of items + num_items = data_recv_port.recv() + data_relay_port.send(num_items) + # Receive and relay data1, data2, ... + for _ in range(int(num_items[0].item())): + data_relay_port.send(data_recv_port.recv()) + rsp = resp_port.recv() + return rsp + def _relay_pm_ack_given_model_id(self, model_id: int): """Relays ack received from ProcessModel given by model id to the runtime.""" @@ -295,7 +351,7 @@ def run(self): In this case iterate through the phases of the Loihi protocol until the last time step is reached. The runtime is informed after the last time step. The loop ends when receiving the STOP command from the runtime.""" - selector = CspSelector() + selector = Selector() phase = LoihiPhase.HOST channel_actions = [(self.runtime_to_service, lambda: "cmd")] @@ -373,7 +429,6 @@ def run(self): if enum_equal(cmd, MGMT_COMMAND.PAUSE): self.pausing = True self.req_pause = True - # If HOST phase (last time step ended) break the loop if enum_equal(phase, LoihiPhase.HOST): break @@ -384,7 +439,7 @@ def run(self): # Inform the runtime that last time step was reached if is_last_ts: self.service_to_runtime.send(MGMT_RESPONSE.DONE) - else: + elif action is not None: self.service_to_runtime.send(MGMT_RESPONSE.ERROR) @@ -457,13 +512,12 @@ def _handle_stop(self): def run(self): """Retrieves commands from the runtime and relays them to the process models. Also send the acknowledgement back to runtime.""" - selector = CspSelector() + selector = Selector() channel_actions = [(self.runtime_to_service, lambda: "cmd")] while True: - # Probe if there is a new command from the runtime action = selector.select(*channel_actions) - channel_actions = [] if action == "cmd": + channel_actions = [] command = self.runtime_to_service.recv() if enum_equal(command, MGMT_COMMAND.STOP): self._handle_stop() @@ -491,6 +545,7 @@ def run(self): (ptos_recv_port, lambda: "resp") ) elif action == "resp": + channel_actions = [] resps = self._get_pm_resp() done: bool = True for resp in resps: @@ -518,8 +573,9 @@ def run(self): if self._error: self.service_to_runtime.send(MGMT_RESPONSE.ERROR) self.running = False - else: + elif action is not None: self.service_to_runtime.send(MGMT_RESPONSE.ERROR) + self.join() self.running = False raise ValueError(f"Wrong type of channel action : {action}") channel_actions.append((self.runtime_to_service, lambda: "cmd")) diff --git a/src/lava/proc/conv/utils.py b/src/lava/proc/conv/utils.py index 06943c095..26d346581 100644 --- a/src/lava/proc/conv/utils.py +++ b/src/lava/proc/conv/utils.py @@ -258,8 +258,8 @@ def conv_scipy(input_: np.ndarray, weight.shape[0], dilation[0] * (kernel_size[0] - 1) + 1, dilation[1] * (kernel_size[1] - 1) + 1, - weight.shape[-1] - ]) + weight.shape[-1]] + ) dilated_weight[:, ::dilation[0], ::dilation[1], :] = weight input_padded = np.pad( diff --git a/src/lava/proc/dense/models.py b/src/lava/proc/dense/models.py index f9b833976..5a03b875a 100644 --- a/src/lava/proc/dense/models.py +++ b/src/lava/proc/dense/models.py @@ -93,6 +93,7 @@ def run_spk(self): # The a_out sent at each timestep is a buffered value from dendritic # accumulation at timestep t-1. This prevents deadlocking in # networks with recurrent connectivity structures. + self.a_buff = self.a_buff.astype(np.int32) self.a_out.send(self.a_buff) if self.num_message_bits.item() > 0: s_in = self.s_in.recv() @@ -183,6 +184,162 @@ def run_spk(self): a_accum = self.weights[:, s_in].sum(axis=1) self.a_buff = ( + np.left_shift(a_accum, self.weight_exp).astype(self.a_buff.dtype) + if self.weight_exp > 0 + else np.right_shift( + a_accum, -self.weight_exp).astype(self.a_buff.dtype) + ) + + self.recv_traces(s_in) + + +class AbstractPyDelayDenseModel(PyLoihiProcessModel): + """Abstract Conn Process with Dense synaptic connections which incorporates + delays into the Conn Process. + """ + + weights: np.ndarray = None + delays: np.ndarray = None + a_buff: np.ndarray = None + + @staticmethod + def get_del_wgts(weights, delays) -> np.ndarray: + """ + Use self.weights and self.delays to create a matrix where the + weights are separated by delay. Returns 2D matrix of form + (num_flat_output_neurons * max_delay + 1, num_flat_input_neurons) where + del_wgts[ + k * num_flat_output_neurons : (k + 1) * num_flat_output_neurons, : + ] + contains the weights for all connections with a delay equal to k. + This allows for the updating of the activation buffer and updating + weights. + """ + return np.vstack([ + np.where(delays == k, weights, 0) + for k in range(np.max(delays) + 1) + ]) + + def calc_act(self, s_in) -> np.ndarray: + """ + Calculate the activations by performing del_wgts * s_in. This matrix + is then summed across each row to get the activations to the output + neurons for different delays. This activation vector is reshaped to a + matrix of the form + (n_flat_output_neurons * (max_delay + 1), n_flat_output_neurons) + which is then transposed to get the activation matrix. + """ + return np.reshape( + np.sum(self.get_del_wgts(self.weights, + self.delays) * s_in, axis=1), + (np.max(self.delays) + 1, self.weights.shape[0])).T + + def update_act(self, s_in): + """ + Updates the activations for the connection. + Clears first column of a_buff and rolls them to the last column. + Finally, calculates the activations for the current time step and adds + them to a_buff. + This order of operations ensures that delays of 0 correspond to + the next time step. + """ + self.a_buff[:, 0] = 0 + self.a_buff = np.roll(self.a_buff, -1) + self.a_buff += self.calc_act(s_in) + + +@implements(proc=DelayDense, protocol=LoihiProtocol) +@requires(CPU) +@tag("floating_pt") +class PyDelayDenseModelFloat(AbstractPyDelayDenseModel): + """Implementation of Conn Process with Dense synaptic connections in + floating point precision. This short and simple ProcessModel can be used + for quick algorithmic prototyping, without engaging with the nuances of a + fixed point implementation. DelayDense incorporates delays into the Conn + Process. + """ + s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1) + a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float) + a_buff: np.ndarray = LavaPyType(np.ndarray, float) + # weights is a 2D matrix of form (num_flat_output_neurons, + # num_flat_input_neurons) in C-order (row major). + weights: np.ndarray = LavaPyType(np.ndarray, float) + # delays is a 2D matrix of form (num_flat_output_neurons, + # num_flat_input_neurons) in C-order (row major). + delays: np.ndarray = LavaPyType(np.ndarray, int) + num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5) + + def run_spk(self): + # The a_out sent on a each timestep is a buffered value from dendritic + # accumulation at timestep t-1. This prevents deadlocking in + # networks with recurrent connectivity structures. + self.a_out.send(self.a_buff[:, 0]) + if self.num_message_bits.item() > 0: + s_in = self.s_in.recv() + else: + s_in = self.s_in.recv().astype(bool) + self.update_act(s_in) + + +@implements(proc=DelayDense, protocol=LoihiProtocol) +@requires(CPU) +@tag("bit_accurate_loihi", "fixed_pt") +class PyDelayDenseModelBitAcc(AbstractPyDelayDenseModel): + """Implementation of Conn Process with Dense synaptic connections that is + bit-accurate with Loihi's hardware implementation of Dense, which means, + it mimics Loihi behaviour bit-by-bit. DelayDense incorporates delays into + the Conn Process. Loihi 2 has a maximum of 6 bits for delays, meaning a + spike can be delayed by 0 to 63 time steps.""" + + s_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1) + a_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32, precision=16) + a_buff: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=16) + # weights is a 2D matrix of form (num_flat_output_neurons, + # num_flat_input_neurons) in C-order (row major). + weights: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=8) + delays: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=6) + num_message_bits: np.ndarray = LavaPyType(np.ndarray, int, precision=5) + + def __init__(self, proc_params): + super().__init__(proc_params) + # Flag to determine whether weights have already been scaled. + self.weights_set = False + + def run_spk(self): + self.weight_exp: int = self.proc_params.get("weight_exp", 0) + + # Since this Process has no learning, weights are assumed to be static + # and only require scaling on the first timestep of run_spk(). + if not self.weights_set: + num_weight_bits: int = self.proc_params.get("num_weight_bits", 8) + sign_mode: SignMode = self.proc_params.get("sign_mode") \ + or determine_sign_mode(self.weights) + + self.weights = clip_weights(self.weights, sign_mode, num_bits=8) + self.weights = truncate_weights(self.weights, + sign_mode, + num_weight_bits) + self.weights_set = True + + # Check if delays are within Loihi 2 constraints + if np.max(self.delays) > 63: + raise ValueError("DelayDense Process 'delays' expects values " + f"between 0 and 63 for Loihi, got " + f"{self.delays}.") + + # The a_out sent at each timestep is a buffered value from dendritic + # accumulation at timestep t-1. This prevents deadlocking in + # networks with recurrent connectivity structures. + self.a_out.send(self.a_buff[:, 0]) + if self.num_message_bits.item() > 0: + s_in = self.s_in.recv() + else: + s_in = self.s_in.recv().astype(bool) + + a_accum = self.calc_act(s_in) + self.a_buff[:, 0] = 0 + self.a_buff = np.roll(self.a_buff, -1) + self.a_buff += ( np.left_shift(a_accum, self.weight_exp) if self.weight_exp > 0 else np.right_shift(a_accum, -self.weight_exp) @@ -191,6 +348,7 @@ def run_spk(self): self.recv_traces(s_in) +# pylint: disable=E0102 class AbstractPyDelayDenseModel(PyLoihiProcessModel): """Abstract Conn Process with Dense synaptic connections which incorporates delays into the Conn Process. @@ -257,6 +415,7 @@ def update_act(self, s_in): @implements(proc=DelayDense, protocol=LoihiProtocol) @requires(CPU) @tag("floating_pt") +# pylint: disable=E0102 class PyDelayDenseModelFloat(AbstractPyDelayDenseModel): """Implementation of Conn Process with Dense synaptic connections in floating point precision. This short and simple ProcessModel can be used @@ -290,6 +449,7 @@ def run_spk(self): @implements(proc=DelayDense, protocol=LoihiProtocol) @requires(CPU) @tag("bit_accurate_loihi", "fixed_pt") +# pylint: disable=E0102 class PyDelayDenseModelBitAcc(AbstractPyDelayDenseModel): """Implementation of Conn Process with Dense synaptic connections that is bit-accurate with Loihi's hardware implementation of Dense, which means, diff --git a/src/lava/proc/io/dataloader.py b/src/lava/proc/io/dataloader.py index 3fde34e5e..5d7e641b9 100644 --- a/src/lava/proc/io/dataloader.py +++ b/src/lava/proc/io/dataloader.py @@ -167,7 +167,9 @@ def __init__( super().__init__(gt_shape, dataset, interval, offset) data_shape = data.shape[:-1] + (interval,) - self.data = Var(shape=data_shape, init=np.zeros(data_shape)) + self.data = Var( + shape=data_shape, + init=np.zeros(data_shape, dtype=data.dtype)) self.s_out = OutPort(shape=data.shape[:-1]) # last dimension is time diff --git a/src/lava/proc/io/extractor.py b/src/lava/proc/io/extractor.py index 156370c6a..c0e53bec9 100644 --- a/src/lava/proc/io/extractor.py +++ b/src/lava/proc/io/extractor.py @@ -13,9 +13,21 @@ from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.model.py.ports import PyInPort -from lava.magma.compiler.channels.pypychannel import PyPyChannel -from lava.magma.runtime.message_infrastructure.multiprocessing import \ - MultiProcessing +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION +if PURE_PYTHON_VERSION: + from lava.magma.runtime.message_infrastructure.py_multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure.pypychannel \ + import PyPyChannel as Channel +else: + from lava.magma.runtime.message_infrastructure.multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure import Channel as Channel + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + import ChannelQueueSize + from lava.proc.io import utils @@ -63,15 +75,15 @@ def __init__(self, self._shape = shape self._multi_processing = MultiProcessing() - self._multi_processing.start() + self._multi_processing.init() # Stands for ProcessModel to Process - pm_to_p = PyPyChannel(message_infrastructure=self._multi_processing, - src_name="src", - dst_name="dst", - shape=self._shape, - dtype=float, - size=buffer_size) + pm_to_p = Channel(message_infrastructure=self._multi_processing, + src_name="src", + dst_name="dst", + shape=self._shape, + dtype=float, + size=buffer_size) self._pm_to_p_dst_port = pm_to_p.dst_port self._pm_to_p_dst_port.start() diff --git a/src/lava/proc/io/injector.py b/src/lava/proc/io/injector.py index e0c4207e5..47d2167c5 100644 --- a/src/lava/proc/io/injector.py +++ b/src/lava/proc/io/injector.py @@ -13,9 +13,21 @@ from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.model.py.ports import PyOutPort -from lava.magma.runtime.message_infrastructure.multiprocessing import \ - MultiProcessing -from lava.magma.compiler.channels.pypychannel import PyPyChannel +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION +if PURE_PYTHON_VERSION: + from lava.magma.runtime.message_infrastructure.py_multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure.pypychannel \ + import PyPyChannel as Channel +else: + from lava.magma.runtime.message_infrastructure.multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure import Channel as Channel + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + import ChannelQueueSize + from lava.proc.io import utils @@ -62,15 +74,15 @@ def __init__(self, utils.validate_channel_config(channel_config) self._multi_processing = MultiProcessing() - self._multi_processing.start() + self._multi_processing.init() # Stands for Process to ProcessModel - p_to_pm = PyPyChannel(message_infrastructure=self._multi_processing, - src_name="src", - dst_name="dst", - shape=shape, - dtype=float, - size=buffer_size) + p_to_pm = Channel(message_infrastructure=self._multi_processing, + src_name="src", + dst_name="dst", + shape=shape, + dtype=float, + size=buffer_size) self._p_to_pm_src_port = p_to_pm.src_port self._p_to_pm_src_port.start() diff --git a/src/lava/proc/io/utils.py b/src/lava/proc/io/utils.py index 7f8bc0ec7..309761317 100644 --- a/src/lava/proc/io/utils.py +++ b/src/lava/proc/io/utils.py @@ -8,7 +8,8 @@ import numpy as np import warnings -from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort +from lava.magma.runtime.message_infrastructure \ + import SendPort as CspSendPort, RecvPort as CspRecvPort class SendFull(IntEnum): diff --git a/src/lava/proc/sparse/models.py b/src/lava/proc/sparse/models.py index d553868e2..bf89ebfc5 100644 --- a/src/lava/proc/sparse/models.py +++ b/src/lava/proc/sparse/models.py @@ -75,6 +75,7 @@ def __init__(self, proc_params): self.weights_set = False def run_spk(self): + # pylint: disable=W0201 self.weight_exp: int = self.proc_params.get("weight_exp", 0) # Since this Process has no learning, weights are assumed to be static diff --git a/tests/lava/magma/compiler/builders/test_builder.py b/tests/lava/magma/compiler/builders/test_builder.py index 160ab67b1..3df614b01 100644 --- a/tests/lava/magma/compiler/builders/test_builder.py +++ b/tests/lava/magma/compiler/builders/test_builder.py @@ -8,13 +8,6 @@ from lava.magma.compiler.builders.channel_builder import ChannelBuilderMp from lava.magma.compiler.builders.py_builder import PyProcessBuilder -from lava.magma.compiler.channels.interfaces import Channel, ChannelType, \ - AbstractCspPort -from lava.magma.compiler.channels.pypychannel import ( - PyPyChannel, - CspSendPort, - CspRecvPort, -) from lava.magma.compiler.utils import VarInitializer, PortInitializer, \ VarPortInitializer from lava.magma.core.decorator import implements, requires @@ -27,18 +20,26 @@ from lava.magma.core.process.process import AbstractProcess from lava.magma.core.process.variable import Var from lava.magma.core.resources import CPU +from lava.magma.runtime.message_infrastructure import ( + create_channel, + Channel, + AbstractTransferPort, +) +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType from lava.magma.runtime.message_infrastructure.shared_memory_manager import ( SharedMemoryManager, ) -from lava.magma.compiler.channels.watchdog import NoOPWatchdogManager +from lava.magma.runtime.message_infrastructure.watchdog import \ + NoOPWatchdogManager class MockMessageInterface: def __init__(self, smm): self.smm = smm - def channel_class(self, channel_type: ChannelType) -> ty.Type: - return PyPyChannel + def channel(self, channel_type: ChannelType, src_name, dst_name, + shape, dtype, size) -> Channel: + return create_channel(self, src_name, dst_name, shape, dtype, size) class TestChannelBuilder(unittest.TestCase): @@ -59,23 +60,21 @@ def test_channel_builder(self): smm.start() mock = MockMessageInterface(smm) - channel: Channel = channel_builder.build(mock, - NoOPWatchdogManager()) - assert isinstance(channel, PyPyChannel) - assert isinstance(channel.src_port, CspSendPort) - assert isinstance(channel.dst_port, CspRecvPort) - - channel.src_port.start() - channel.dst_port.start() - - expected_data = np.array([[1, 2]]) - channel.src_port.send(data=expected_data) - data = channel.dst_port.recv() - assert np.array_equal(data, expected_data) - - channel.src_port.join() - channel.dst_port.join() - + channel: Channel = channel_builder.build(mock) + assert isinstance(channel, Channel) + assert isinstance(channel.src_port, AbstractTransferPort) + assert isinstance(channel.dst_port, AbstractTransferPort) + try: + channel.src_port.start() + channel.dst_port.start() + + expected_data = np.array([[1, 2]], dtype=np.int32) + channel.src_port.send(expected_data) + data = channel.dst_port.recv() + assert np.array_equal(data, expected_data) + finally: + channel.src_port.join() + channel.dst_port.join() finally: smm.shutdown() @@ -119,7 +118,7 @@ def run(self): # A fake CspPort just to test ProcBuilder -class FakeCspPort(AbstractCspPort): +class FakeCspPort: def __init__(self, name="mock"): self._name = name diff --git a/tests/lava/magma/compiler/channels/test_pypychannel.py b/tests/lava/magma/compiler/channels/test_channel.py similarity index 70% rename from tests/lava/magma/compiler/channels/test_pypychannel.py rename to tests/lava/magma/compiler/channels/test_channel.py index 62871d01a..6b3320ba8 100644 --- a/tests/lava/magma/compiler/channels/test_pypychannel.py +++ b/tests/lava/magma/compiler/channels/test_channel.py @@ -4,9 +4,12 @@ import numpy as np import unittest +import time from multiprocessing import Process - -from lava.magma.compiler.channels.pypychannel import PyPyChannel +from lava.magma.runtime.message_infrastructure import ( + create_channel, + Channel, +) from lava.magma.runtime.message_infrastructure.shared_memory_manager import ( SharedMemoryManager ) @@ -17,17 +20,15 @@ def __init__(self, smm): self.smm = smm -def get_channel(smm, data, size, name="test_channel") -> PyPyChannel: +def get_channel(smm, data, name="test_channel") -> Channel: mock = MockInterface(smm) - channel = PyPyChannel( + return create_channel( message_infrastructure=mock, - src_name=name, - dst_name=name, + src_name=name + "src", + dst_name=name + "dst", shape=data.shape, dtype=data.dtype, - size=size - ) - return channel + size=data.size) class TestPyPyChannelSingleProcess(unittest.TestCase): @@ -35,16 +36,18 @@ def test_send_recv_single_process(self): smm = SharedMemoryManager() try: smm.start() - data = np.ones((2, 2, 2)) - channel = get_channel(smm, data, size=2) - - channel.src_port.start() - channel.dst_port.start() - - channel.src_port.send(data=data) - result = channel.dst_port.recv() - assert np.array_equal(result, data) + channel = get_channel(smm, data) + try: + channel.src_port.start() + channel.dst_port.start() + + channel.src_port.send(data) + result = channel.dst_port.recv() + assert np.array_equal(result, data) + finally: + channel.src_port.join() + channel.dst_port.join() finally: smm.shutdown() @@ -52,16 +55,18 @@ def test_send_recv_single_process_2d_data(self): smm = SharedMemoryManager() try: smm.start() - data = np.random.randint(100, size=(100, 100), dtype=np.int32) - channel = get_channel(smm, data, size=100) - - channel.src_port.start() - channel.dst_port.start() - - channel.src_port.send(data=data) - result = channel.dst_port.recv() - assert np.array_equal(result, data) + channel = get_channel(smm, data) + try: + channel.src_port.start() + channel.dst_port.start() + + channel.src_port.send(data) + result = channel.dst_port.recv() + assert np.array_equal(result, data) + finally: + channel.src_port.join() + channel.dst_port.join() finally: smm.shutdown() @@ -69,16 +74,18 @@ def test_send_recv_single_process_1d_data(self): smm = SharedMemoryManager() try: smm.start() - data = np.random.randint(1000, size=100, dtype=np.int16) - channel = get_channel(smm, data, size=10) - - channel.src_port.start() - channel.dst_port.start() - - channel.src_port.send(data=data) - result = channel.dst_port.recv() - assert np.array_equal(result, data) + channel = get_channel(smm, data) + try: + channel.src_port.start() + channel.dst_port.start() + + channel.src_port.send(data) + result = channel.dst_port.recv() + assert np.array_equal(result, data) + finally: + channel.src_port.join() + channel.dst_port.join() finally: smm.shutdown() @@ -93,7 +100,11 @@ def __init__(self, ports=None, **kwargs): def run(self): for c in self._ports: c.start() + # need to wait all port started. + time.sleep(0.01) super().run() + for c in self._ports: + c.join() def source(shape, port): @@ -128,18 +139,18 @@ def buffer(shape, dst_port, src_port): class TestPyPyChannelMultiProcess(unittest.TestCase): + def test_send_recv_relay(self): smm = SharedMemoryManager() try: smm.start() data = np.ones((2, 2)) channel_source_to_buffer = get_channel( - smm, data, size=2, name="channel_source_to_buffer" + smm, data, name="channel_source_to_buffer" ) channel_buffer_to_sink = get_channel( - smm, data, size=2, name="channel_buffer_to_sink" + smm, data, name="channel_buffer_to_sink" ) - jobs = [ DummyProcess( ports=(channel_source_to_buffer.src_port,), diff --git a/tests/lava/magma/compiler/subcompilers/test_channel_builders_factory.py b/tests/lava/magma/compiler/subcompilers/test_channel_builders_factory.py index 41dded64a..2c535837e 100644 --- a/tests/lava/magma/compiler/subcompilers/test_channel_builders_factory.py +++ b/tests/lava/magma/compiler/subcompilers/test_channel_builders_factory.py @@ -9,7 +9,7 @@ from lava.magma.compiler.builders.channel_builder import ChannelBuilderMp from lava.magma.compiler.channel_map import ChannelMap, Payload, PortPair -from lava.magma.compiler.channels.interfaces import ChannelType +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType from lava.magma.compiler.compiler_graphs import ProcGroupDiGraphs from lava.magma.compiler.subcompilers.channel_builders_factory import \ ChannelBuildersFactory @@ -18,7 +18,8 @@ from lava.magma.core.model.interfaces import AbstractPortImplementation from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.model.py.ports import (PyInPort, PyOutPort, PyRefPort, - PyVarPort) + PyVarPort, + AbstractPortImplementation) from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.model.sub.model import AbstractSubProcessModel from lava.magma.core.process.ports.ports import (AbstractPort, InPort, OutPort, diff --git a/tests/lava/magma/compiler/test_channel_builder.py b/tests/lava/magma/compiler/test_channel_builder.py index 92da9b856..712e53900 100644 --- a/tests/lava/magma/compiler/test_channel_builder.py +++ b/tests/lava/magma/compiler/test_channel_builder.py @@ -6,27 +6,31 @@ import unittest import numpy as np - +from multiprocessing.managers import SharedMemoryManager from lava.magma.compiler.builders.channel_builder import ChannelBuilderMp -from lava.magma.compiler.channels.interfaces import Channel, ChannelType from lava.magma.compiler.utils import PortInitializer -from lava.magma.compiler.channels.pypychannel import ( - PyPyChannel, - CspSendPort, - CspRecvPort, + +from lava.magma.runtime.message_infrastructure import ( + Channel, + SendPort, + RecvPort, + create_channel ) +from lava.magma.runtime.message_infrastructure.interfaces import ChannelType from lava.magma.runtime.message_infrastructure.shared_memory_manager import ( SharedMemoryManager ) -from lava.magma.compiler.channels.watchdog import NoOPWatchdogManager +from lava.magma.runtime.message_infrastructure.watchdog import \ + NoOPWatchdogManager class MockMessageInterface: def __init__(self, smm): self.smm = smm - def channel_class(self, channel_type: ChannelType) -> ty.Type: - return PyPyChannel + def channel(self, channel_type: ChannelType, src_name, dst_name, + shape, dtype, size) -> Channel: + return create_channel(self, src_name, dst_name, shape, dtype, size) class TestChannelBuilder(unittest.TestCase): @@ -44,20 +48,17 @@ def test_channel_builder(self): src_process=None, dst_process=None, ) - smm.start() mock = MockMessageInterface(smm) - channel: Channel = channel_builder.build(mock, - NoOPWatchdogManager()) - assert isinstance(channel, PyPyChannel) - assert isinstance(channel.src_port, CspSendPort) - assert isinstance(channel.dst_port, CspRecvPort) + channel: Channel = channel_builder.build(mock) + self.assertIsInstance(channel.src_port, SendPort) + self.assertIsInstance(channel.dst_port, RecvPort) channel.src_port.start() channel.dst_port.start() - expected_data = np.array([[1, 2]]) - channel.src_port.send(data=expected_data) + expected_data = np.array([[1, 2]], dtype=np.int32) + channel.src_port.send(expected_data) data = channel.dst_port.recv() assert np.array_equal(data, expected_data) diff --git a/tests/lava/magma/core/model/py/test_ports.py b/tests/lava/magma/core/model/py/test_ports.py index 31bf2c325..bc8fcc200 100644 --- a/tests/lava/magma/core/model/py/test_ports.py +++ b/tests/lava/magma/core/model/py/test_ports.py @@ -7,12 +7,13 @@ import numpy as np import typing as ty import functools as ft +from multiprocessing.managers import SharedMemoryManager +from lava.magma.runtime.message_infrastructure import ( + Channel, + create_channel, + AbstractTransferPort, +) -from lava.magma.compiler.channels.interfaces import ( - AbstractCspPort, - AbstractCspSendPort, - AbstractCspRecvPort) -from lava.magma.compiler.channels.pypychannel import PyPyChannel from lava.magma.core.model.py.ports import ( PyInPort, PyInPortVectorDense, @@ -30,16 +31,15 @@ def __init__(self, smm): self.smm = smm -def get_channel(smm, data, size, name="test_channel") -> PyPyChannel: +def get_channel(smm, data, name="test_channel") -> Channel: mock = MockInterface(smm) - return PyPyChannel( + return create_channel( message_infrastructure=mock, - src_name=name, - dst_name=name, + src_name=name + "src", + dst_name=name + "dst", shape=data.shape, dtype=data.dtype, - size=size - ) + size=data.size) class TestPyPorts(unittest.TestCase): @@ -50,16 +50,15 @@ def probe_test_routine(self, cls): try: smm.start() - data = np.ones((4, 4)) - channel_1 = get_channel(smm, data, data.size) - send_csp_port_1: AbstractCspSendPort = channel_1.src_port - recv_csp_port_1: AbstractCspRecvPort = channel_1.dst_port + channel_1 = get_channel(smm, data) + send_csp_port_1: AbstractTransferPort = channel_1.src_port + recv_csp_port_1: AbstractTransferPort = channel_1.dst_port - channel_2 = get_channel(smm, data, data.size) - send_csp_port_2: AbstractCspSendPort = channel_2.src_port - recv_csp_port_2: AbstractCspRecvPort = channel_2.dst_port + channel_2 = get_channel(smm, data) + send_csp_port_2: AbstractTransferPort = channel_2.src_port + recv_csp_port_2: AbstractTransferPort = channel_2.dst_port # Create two different PyOutPort send_py_port_1: PyOutPort = \ @@ -73,29 +72,36 @@ def probe_test_routine(self, cls): cls([recv_csp_port_1, recv_csp_port_2], None, data.shape, data.dtype) - recv_py_port.start() - send_py_port_1.start() - send_py_port_2.start() - - # Send data through first PyOutPort - send_py_port_1.send(data) - # Send data through second PyOutPort - send_py_port_2.send(data) - # Sleep to let message reach the PyInPort - time.sleep(0.001) - # Probe PyInPort - probe_value = recv_py_port.probe() - - # probe_value should be True if message reached the PyInPort - self.assertTrue(probe_value) - - # Get data that reached PyInPort to empty buffer - _ = recv_py_port.recv() - # Probe PyInPort - probe_value = recv_py_port.probe() - - # probe_value should be False since PyInPort's buffer was emptied - self.assertFalse(probe_value) + try: + recv_py_port.start() + send_py_port_1.start() + send_py_port_2.start() + + # Send data through first PyOutPort + send_py_port_1.send(data) + # Send data through second PyOutPort + send_py_port_2.send(data) + # Sleep to let message reach the PyInPort + time.sleep(0.01) + # Probe PyInPort + probe_value = recv_py_port.probe() + + # probe_value should be True if message reached the PyInPort + self.assertTrue(probe_value) + + # Get data that reached PyInPort to empty buffer + _ = recv_py_port.recv() + # Probe PyInPort + probe_value = recv_py_port.probe() + + # probe_value should be False since + # PyInPort's buffer was emptied + self.assertFalse(probe_value) + finally: + send_csp_port_1.join() + recv_csp_port_1.join() + send_py_port_1.join() + recv_csp_port_2.join() finally: smm.shutdown() @@ -107,7 +113,8 @@ def test_py_in_port_probe(self): self.probe_test_routine(cls) -class MockCspPort(AbstractCspPort): +class MockCspPort: + @property def name(self) -> str: return "mock_csp_port" @@ -167,7 +174,7 @@ def test_transform(self) -> None: def test_transformation_defaults_to_identity(self) -> None: """Tests whether the transformation defaults to the identity transformation when no transformation functions are specified.""" - del(self.transform_funcs["id0"]) + del (self.transform_funcs["id0"]) vpt = VirtualPortTransformer(self.csp_ports, self.transform_funcs) data = np.array(5) diff --git a/tests/lava/magma/core/process/test_process.py b/tests/lava/magma/core/process/test_process.py index 6c690ba9e..9e56a2841 100644 --- a/tests/lava/magma/core/process/test_process.py +++ b/tests/lava/magma/core/process/test_process.py @@ -294,6 +294,7 @@ def test_compile(self) -> None: self.assertIsInstance(e, Executable) self.assertEqual(len(e.proc_builders), 1) + @unittest.skip("This case cannot run becase the MinimalPyProcessModel") def test_create_runtime(self) -> None: """Tests the create_runtime method.""" p = MinimalProcess() diff --git a/tests/lava/magma/runtime/message_infrastructure/test_all_delivery.py b/tests/lava/magma/runtime/message_infrastructure/test_all_delivery.py new file mode 100644 index 000000000..1ca2b38d1 --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_all_delivery.py @@ -0,0 +1,681 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import unittest +from functools import partial +import time +from datetime import datetime +from multiprocessing import shared_memory +from multiprocessing import Semaphore +from multiprocessing import Process +from lava.magma.runtime.message_infrastructure import ( + PURE_PYTHON_VERSION, + Channel, + SupportGRPCChannel, + SupportFastDDSChannel, + SupportCycloneDDSChannel +) + + +class PyChannel: + + def __init__(self, dtype, size, nbytes, name, *_) -> None: + self.shm_ = Shm(dtype, size, nbytes, name) + self.src_port = Port(self.shm_) + self.dst_port = Port(self.shm_) + self.dtype_ = dtype + self.shm_.start() + + def start(self): + pass + + def join(self): + pass + + +class Port: + + def __init__(self, shm) -> None: + self.shm_ = shm + + def send(self, data): + return self.shm_.push(data) + + def recv(self): + return self.shm_.pop() + + def start(self): + pass + + def join(self): + pass + + +class Shm: + def __init__(self, dtype, size, nbytes, name) -> None: + self.shm_ = shared_memory.SharedMemory(name=name, + create=True, size=nbytes * size) + self.nbytes_ = nbytes + self.size_ = size + self.sem_ack_ = Semaphore(size) + self.sem_req_ = Semaphore(0) + self.sem_ = Semaphore(0) + self.read_ = 0 + self.write_ = 0 + self.type_ = dtype + + def push(self, data): + self.sem_ack_.acquire() + self.sem_.acquire() + self.shm_.buf[self.write_ * self.nbytes_: + ((self.write_ + 1) * self.nbytes_)] = bytearray(data) + self.write_ = (self.write_ + 1) % self.size_ + self.type_ = data.dtype + self.sem_.release() + self.sem_req_.release() + + def pop(self): + self.sem_req_.acquire() + self.sem_.acquire() + result = bytearray(self.shm_.buf[self.read_ * self.nbytes_: + ((self.read_ + 1) * self.nbytes_)]) + self.read_ = (self.read_ + 1) % self.size_ + self.sem_.release() + self.sem_ack_.release() + return np.frombuffer(result, self.type_) + + def start(self): + self.sem_.release() + + def __del__(self): + self.shm_.close() + self.shm_.unlink() + + +class Builder: + def build(self, i): + pass + + +def bound_target_a1(loop, mp_to_a1, a1_to_a2, + a2_to_a1, a1_to_mp, builder): + from_mp = mp_to_a1.dst_port + from_mp.start() + to_a2 = a1_to_a2.src_port + to_a2.start() + from_a2 = a2_to_a1.dst_port + from_a2.start() + to_mp = a1_to_mp.src_port + to_mp.start() + while loop > 0: + loop = loop - 1 + data = from_mp.recv() + data[0] = data[0] + 1 + to_a2.send(data) + data = from_a2.recv() + data[0] = data[0] + 1 + to_mp.send(data) + + from_mp.join() + to_a2.join() + from_a2.join() + to_mp.join() + + +def bound_target_a2(loop, a1_to_a2, a2_to_a1, builder): + from_a1 = a1_to_a2.dst_port + from_a1.start() + to_a1 = a2_to_a1.src_port + to_a1.start() + while loop > 0: + loop = loop - 1 + data = from_a1.recv() + data[0] = data[0] + 1 + to_a1.send(data) + + from_a1.join() + to_a1.join() + + +def prepare_data(): + arr1 = np.array([1] * 9990) + arr2 = np.array([1, 2, 3, 4, 5, + 6, 7, 8, 9, 0]) + return np.concatenate((arr2, arr1)) + + +class TestAllDelivery(unittest.TestCase): + + def __init__(self, methodName: str = ...) -> None: + super().__init__(methodName) + self.loop_ = 1000 + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_cpp_shm_loop_with_cpp_multiprocess(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + loop = self.loop_ + mp = MultiProcessing() + mp.start() + predata = prepare_data() + queue_size = 1 + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + mp_to_a1 = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "mp_to_a1", + "mp_to_a1", + (2, 2), + np.int32) + a1_to_a2 = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "a1_to_a2", + "a1_to_a2", + (2, 2), + np.int32) + a2_to_a1 = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "a2_to_a1", + "a2_to_a1", + (2, 2), + np.int32) + a1_to_mp = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "a1_to_mp", + "a1_to_mp", + (2, 2), + np.int32) + + target_a1 = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp) + target_a2 = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1) + + builder = Builder() + + mp.build_actor(target_a1, builder) # actor1 + mp.build_actor(target_a2, builder) # actor2 + + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + + to_a1.start() + from_a1.start() + + expect_result = np.copy(predata) + expect_result[0] = (1 + 3 * loop) + loop_start = datetime.now() + while loop > 0: + loop = loop - 1 + to_a1.send(predata) + predata = from_a1.recv() + loop_end = datetime.now() + print("cpp_shm_loop_with_cpp_multiprocess result = ", predata[0]) + if not np.array_equal(expect_result, predata): + print("expect: ", expect_result) + print("result: ", predata) + raise AssertionError() + + to_a1.join() + from_a1.join() + mp.stop() + mp.cleanup(True) + print("cpp_shm_loop_with_cpp_multiprocess timedelta =", + loop_end - loop_start) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_cpp_skt_loop_with_cpp_multiprocess(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + loop = self.loop_ + mp = MultiProcessing() + mp.start() + predata = prepare_data() + queue_size = 2 + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + mp_to_a1 = Channel( + ChannelType.SOCKETCHANNEL, + queue_size, + nbytes, + "mp_to_a1", + "mp_to_a1", + (2, 2), + np.int32) + a1_to_a2 = Channel( + ChannelType.SOCKETCHANNEL, + queue_size, + nbytes, + "a1_to_a2", + "a1_to_a2", + (2, 2), + np.int32) + a2_to_a1 = Channel( + ChannelType.SOCKETCHANNEL, + queue_size, + nbytes, + "a2_to_a1", + "a2_to_a1", + (2, 2), + np.int32) + a1_to_mp = Channel( + ChannelType.SOCKETCHANNEL, + queue_size, + nbytes, + "a1_to_mp", + "a1_to_mp", + (2, 2), + np.int32) + + target_a1 = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp) + target_a2 = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1) + + builder = Builder() + + mp.build_actor(target_a1, builder) # actor1 + mp.build_actor(target_a2, builder) # actor2 + + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + + to_a1.start() + from_a1.start() + + expect_result = np.copy(predata) + expect_result[0] = (1 + 3 * loop) + loop_start = datetime.now() + while loop > 0: + loop = loop - 1 + to_a1.send(predata) + predata = from_a1.recv() + loop_end = datetime.now() + print("cpp_skt_loop_with_cpp_multiprocess result = ", predata[0]) + if not np.array_equal(expect_result, predata): + print("expect: ", expect_result) + print("result: ", predata) + raise AssertionError() + + to_a1.join() + from_a1.join() + mp.stop() + mp.cleanup(True) + print("cpp_skt_loop_with_cpp_multiprocess timedelta =", + loop_end - loop_start) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_py_shm_loop_with_cpp_multiprocess(self): + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + loop = self.loop_ + mp = MultiProcessing() + mp.start() + + predata = prepare_data() + queue_size = 1 + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + mp_to_a1 = PyChannel( + predata.dtype, + queue_size, + nbytes, + "mp_to_a1", + "mp_to_a1") + a1_to_a2 = PyChannel( + predata.dtype, + queue_size, + nbytes, + "a1_to_a2", + "a1_to_a2") + a2_to_a1 = PyChannel( + predata.dtype, + queue_size, + nbytes, + "a2_to_a1", + "a2_to_a1") + a1_to_mp = PyChannel( + predata.dtype, + queue_size, + nbytes, + "a1_to_mp", + "a1_to_mp") + + builder = Builder() + + target_a1 = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp) + target_a2 = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1) + + mp.build_actor(target_a1, builder) # actor1 + mp.build_actor(target_a2, builder) # actor2 + + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + + to_a1.start() + from_a1.start() + + expect_result = np.copy(predata) + expect_result[0] = (1 + 3 * loop) + + loop_start = datetime.now() + while loop > 0: + loop = loop - 1 + to_a1.send(predata) + predata = from_a1.recv() + loop_end = datetime.now() + print("py_shm_loop_with_cpp_multiprocess result = ", predata[0]) + if not np.array_equal(expect_result, predata): + print("expect: ", expect_result) + print("result: ", predata) + raise AssertionError() + + to_a1.join() + from_a1.join() + mp.stop() + mp.cleanup(True) + print("py_shm_loop_with_cpp_multiprocess timedelta =", + loop_end - loop_start) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_py_shm_loop_with_py_multiprocess(self): + loop = self.loop_ + predata = prepare_data() + queue_size = 1 + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + mp_to_a1 = PyChannel( + predata.dtype, + queue_size, + nbytes, + "mp_to_a1", + "mp_to_a1") + a1_to_a2 = PyChannel( + predata.dtype, + queue_size, + nbytes, + "a1_to_a2", + "a1_to_a2") + a2_to_a1 = PyChannel( + predata.dtype, + queue_size, + nbytes, + "a2_to_a1", + "a2_to_a1") + a1_to_mp = PyChannel( + predata.dtype, + queue_size, + nbytes, + "a1_to_mp", + "a1_to_mp") + + builder = Builder() + + target_a1 = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp, builder) + target_a2 = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1, + builder) + + a1 = Process(target=target_a1) + a2 = Process(target=target_a2) + a1.start() + a2.start() + + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + + to_a1.start() + from_a1.start() + + expect_result = np.copy(predata) + expect_result[0] = (1 + 3 * loop) + + loop_start = datetime.now() + while loop > 0: + loop = loop - 1 + to_a1.send(predata) + predata = from_a1.recv() + loop_end = datetime.now() + print("py_shm_loop_with_py_multiprocess result = ", predata[0]) + if not np.array_equal(expect_result, predata): + print("expect: ", expect_result) + print("result: ", predata) + raise AssertionError() + + to_a1.join() + from_a1.join() + a1.terminate() + a2.terminate() + a1.join() + a2.join() + print("py_shm_loop_with_py_multiprocess timedelta =", + loop_end - loop_start) + + @unittest.skipIf(not SupportGRPCChannel, "Not support grpc channel.") + def test_grpcchannel(self): + from lava.magma.runtime.message_infrastructure import GetRPCChannel + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + mp = MultiProcessing() + mp.start() + loop = self.loop_ + a1_to_a2 = GetRPCChannel( + '127.13.2.11', + 8001, + 'a1_to_a2', + 'a1_to_a2', 8) + a2_to_a1 = GetRPCChannel( + '127.13.2.12', + 8002, + 'a2_to_a1', + 'a2_to_a1', 8) + mp_to_a1 = GetRPCChannel( + '127.13.2.13', + 8003, + 'mp_to_a1', + 'mp_to_a1', 8) + a1_to_mp = GetRPCChannel( + '127.13.2.14', + 8004, + 'a1_to_mp', + 'a1_to_mp', 8) + + recv_port_fn = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp) + send_port_fn = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + to_a1.start() + from_a1.start() + data = prepare_data() + expect_result = prepare_data() + expect_result[0] = (1 + 3 * loop) + loop_start_time = datetime.now() + while loop: + to_a1.send(data) + data = from_a1.recv() + loop -= 1 + print("cpp_grpc_loop_with_cpp_multiprocess result = ", data[0]) + loop_end_time = datetime.now() + from_a1.join() + to_a1.join() + mp.stop() + mp.cleanup(True) + if not np.array_equal(expect_result, data): + print("expect: ", expect_result) + print("result: ", data) + raise AssertionError() + print("cpp_grpc_loop_with_cpp_multiprocess timedelta =", + loop_end_time - loop_start_time) + + @unittest.skipIf(not SupportFastDDSChannel, + "Not support FastDDS channel.") + def test_fastdds_channel(self): + from lava.magma.runtime.message_infrastructure import ( + GetDDSChannel, + DDSTransportType, + DDSBackendType, + ChannelQueueSize) + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + mp = MultiProcessing() + mp.start() + loop = self.loop_ + a1_to_a2 = GetDDSChannel( + "a1_to_a2", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend + ) + a2_to_a1 = GetDDSChannel( + "a2_to_a1", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend + ) + mp_to_a1 = GetDDSChannel( + "mp_to_a1", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend + ) + a1_to_mp = GetDDSChannel( + "a1_to_mp", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend + ) + + recv_port_fn = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp) + send_port_fn = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + to_a1.start() + from_a1.start() + data = prepare_data() + expect_result = prepare_data() + expect_result[0] = (1 + 3 * loop) + loop_start_time = datetime.now() + while loop: + to_a1.send(data) + data = from_a1.recv() + loop -= 1 + print("cpp_fastdds_loop_with_cpp_multiprocess result = ", data[0]) + loop_end_time = datetime.now() + if not np.array_equal(expect_result, data): + print("expect: ", expect_result) + print("result: ", data) + raise AssertionError() + print("cpp_fastdds_loop_with_cpp_multiprocess timedelta =", + loop_end_time - loop_start_time) + mp.stop() + from_a1.join() + to_a1.join() + time.sleep(0.1) + + @unittest.skipIf(not SupportCycloneDDSChannel, + "Not support CycloneDDS channel.") + def test_cyclonedds_channel(self): + from lava.magma.runtime.message_infrastructure import ( + GetDDSChannel, + ChannelQueueSize, + DDSTransportType, + DDSBackendType) + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + mp = MultiProcessing() + mp.start() + loop = self.loop_ + a1_to_a2 = GetDDSChannel( + "a1_to_a2", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend + ) + a2_to_a1 = GetDDSChannel( + "a2_to_a1", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend + ) + mp_to_a1 = GetDDSChannel( + "mp_to_a1", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend + ) + a1_to_mp = GetDDSChannel( + "a1_to_mp", + ChannelQueueSize, + DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend + ) + + recv_port_fn = partial(bound_target_a1, loop, mp_to_a1, + a1_to_a2, a2_to_a1, a1_to_mp) + send_port_fn = partial(bound_target_a2, loop, a1_to_a2, a2_to_a1) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + to_a1 = mp_to_a1.src_port + from_a1 = a1_to_mp.dst_port + to_a1.start() + from_a1.start() + data = prepare_data() + expect_result = prepare_data() + expect_result[0] = (1 + 3 * loop) + loop_start_time = datetime.now() + while loop: + to_a1.send(data) + data = from_a1.recv() + loop -= 1 + print("cpp_cyclonedds_loop_with_cpp_multiprocess result = ", data[0]) + loop_end_time = datetime.now() + if not np.array_equal(expect_result, data): + print("expect: ", expect_result) + print("result: ", data) + raise AssertionError() + print("cpp_cylonedds_loop_with_cpp_multiprocess timedelta =", + loop_end_time - loop_start_time) + mp.stop() + from_a1.join() + to_a1.join() + time.sleep(0.1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/magma/runtime/message_infrastructure/test_channel.py b/tests/lava/magma/runtime/message_infrastructure/test_channel.py new file mode 100644 index 000000000..8850d692f --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_channel.py @@ -0,0 +1,302 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import unittest +from functools import partial +import time + +from lava.magma.runtime.message_infrastructure import ( + PURE_PYTHON_VERSION, + Channel, + SendPort, + RecvPort, + SupportGRPCChannel, + SupportFastDDSChannel, + SupportCycloneDDSChannel +) + + +def prepare_data(): + return np.random.random_sample((2, 4)) + + +const_data = prepare_data() + + +def send_proc(*args, **kwargs): + port = kwargs.pop("port") + if not isinstance(port, SendPort): + raise AssertionError() + port.start() + port.send(const_data) + port.join() + + +def recv_proc(*args, **kwargs): + port = kwargs.pop("port") + port.start() + if not isinstance(port, RecvPort): + raise AssertionError() + data = port.recv() + if not np.array_equal(data, const_data): + raise AssertionError() + port.join() + + +class Builder: + def build(self): + pass + + +def ddschannel_protocol(transfer_type, backend, topic_name): + from lava.magma.runtime.message_infrastructure import ( + GetDDSChannel, + ChannelQueueSize) + from lava.magma.runtime.message_infrastructure \ + .multiprocessing import MultiProcessing + mp = MultiProcessing() + mp.start() + dds_channel = GetDDSChannel( + topic_name, + ChannelQueueSize, + transfer_type, + backend) + + send_port = dds_channel.src_port + recv_port = dds_channel.dst_port + + recv_port_fn = partial(recv_proc, port=recv_port) + send_port_fn = partial(send_proc, port=send_port) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + + time.sleep(0.1) + mp.stop() + mp.cleanup(True) + + +class TestChannel(unittest.TestCase): + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_shmemchannel(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + .multiprocessing import MultiProcessing + from lava.magma.runtime.message_infrastructure \ + import ChannelQueueSize + + mp = MultiProcessing() + mp.start() + nbytes = np.prod(const_data.shape) * const_data.dtype.itemsize + name = 'test_shmem_channel' + + shmem_channel = Channel( + ChannelType.SHMEMCHANNEL, + ChannelQueueSize, + nbytes, + name, + name, + (2, 4), + const_data.dtype) + + send_port = shmem_channel.src_port + recv_port = shmem_channel.dst_port + + recv_port_fn = partial(recv_proc, port=recv_port) + send_port_fn = partial(send_proc, port=send_port) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + + time.sleep(0.1) + mp.stop() + mp.cleanup(True) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_single_process_shmemchannel(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + import ChannelQueueSize + + predata = prepare_data() + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + name = 'test_single_process_shmem_channel' + + shmem_channel = Channel( + ChannelType.SHMEMCHANNEL, + ChannelQueueSize, + nbytes, + name, + name, + (2, 4), + const_data.dtype) + + send_port = shmem_channel.src_port + recv_port = shmem_channel.dst_port + + send_port.start() + recv_port.start() + + send_port.send(predata) + resdata = recv_port.recv() + + if not np.array_equal(resdata, predata): + raise AssertionError() + + self.assertTrue(send_port.shape, (2, 4)) + self.assertTrue(recv_port.d_type, np.int32) + + send_port.join() + recv_port.join() + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_socketchannel(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + .multiprocessing import MultiProcessing + from lava.magma.runtime.message_infrastructure \ + import ChannelQueueSize + + mp = MultiProcessing() + mp.start() + nbytes = np.prod(const_data.shape) * const_data.dtype.itemsize + name = 'test_socket_channel' + + socket_channel = Channel( + ChannelType.SOCKETCHANNEL, + ChannelQueueSize, + nbytes, + name, + name, + (2, 4), + const_data.dtype) + + send_port = socket_channel.src_port + recv_port = socket_channel.dst_port + + recv_port_fn = partial(recv_proc, port=recv_port) + send_port_fn = partial(send_proc, port=send_port) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + + time.sleep(0.1) + mp.stop() + mp.cleanup(True) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_single_process_socketchannel(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure import \ + ChannelQueueSize + predata = prepare_data() + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + name = 'test_single_process_socket_channel' + + socket_channel = Channel( + ChannelType.SOCKETCHANNEL, + ChannelQueueSize, + nbytes, + name, + name, + (2, 4), + const_data.dtype) + + send_port = socket_channel.src_port + recv_port = socket_channel.dst_port + + send_port.start() + recv_port.start() + + send_port.send(predata) + resdata = recv_port.recv() + + if not np.array_equal(resdata, predata): + raise AssertionError() + + send_port.join() + recv_port.join() + + @unittest.skipIf(not SupportGRPCChannel, "Not support grpc channel.") + def test_grpcchannel(self): + from lava.magma.runtime.message_infrastructure import GetRPCChannel + from lava.magma.runtime.message_infrastructure \ + .multiprocessing import MultiProcessing + + mp = MultiProcessing() + mp.start() + name = 'test_grpc_channel' + url = '127.13.2.11' + port = 8003 + grpc_channel = GetRPCChannel( + url, + port, + name, + name, + 1) + + send_port = grpc_channel.src_port + recv_port = grpc_channel.dst_port + + recv_port_fn = partial(recv_proc, port=recv_port) + send_port_fn = partial(send_proc, port=send_port) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + + time.sleep(0.1) + mp.stop() + mp.cleanup(True) + + @unittest.skipIf(not SupportFastDDSChannel, "Not support fastdds channel.") + def test_fastdds_channel_shm(self): + from lava.magma.runtime.message_infrastructure import DDSTransportType + from lava.magma.runtime.message_infrastructure import DDSBackendType + ddschannel_protocol(DDSTransportType.DDSSHM, + DDSBackendType.FASTDDSBackend, + "test_fastdds_channel_shm") + + @unittest.skipIf(not SupportFastDDSChannel, "Not support fastdds channel.") + def test_fastdds_channel_udpv4(self): + from lava.magma.runtime.message_infrastructure import DDSTransportType + from lava.magma.runtime.message_infrastructure import DDSBackendType + ddschannel_protocol(DDSTransportType.DDSUDPv4, + DDSBackendType.FASTDDSBackend, + "test_fastdds_channel_udpv4") + + @unittest.skipIf(not SupportCycloneDDSChannel, + "Not support cyclonedds channel.") + def test_cyclonedds_channel_shm(self): + from lava.magma.runtime.message_infrastructure import DDSTransportType + from lava.magma.runtime.message_infrastructure import DDSBackendType + ddschannel_protocol(DDSTransportType.DDSSHM, + DDSBackendType.CycloneDDSBackend, + "test_cyclonedds_shm") + + @unittest.skipIf(not SupportCycloneDDSChannel, + "Not support cyclonedds channel.") + def test_cyclonedds_channel_udpv4(self): + from lava.magma.runtime.message_infrastructure import DDSTransportType + from lava.magma.runtime.message_infrastructure import DDSBackendType + ddschannel_protocol(DDSTransportType.DDSUDPv4, + DDSBackendType.CycloneDDSBackend, + "test_cyclonedds_udpv4") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lava/magma/runtime/message_infrastructure/test_channel_block.py b/tests/lava/magma/runtime/message_infrastructure/test_channel_block.py new file mode 100644 index 000000000..dfa0d4a07 --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_channel_block.py @@ -0,0 +1,81 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import unittest +from functools import partial +import time + +from lava.magma.runtime.message_infrastructure import ( + PURE_PYTHON_VERSION, + Channel, + SendPort, + RecvPort +) + +QUEUE_SIZE = 10 + + +def generate_data(): + return np.random.random_sample((2, 4)) + + +def send_proc(*args, **kwargs): + port = kwargs.pop("port") + + if not isinstance(port, SendPort): + raise AssertionError() + port.start() + for _ in range(QUEUE_SIZE + 1): + data = generate_data() + port.send(data) + + +def recv_proc(*args, **kwargs): + port = kwargs.pop("port") + port.start() + if not isinstance(port, RecvPort): + raise AssertionError() + time.sleep(1) + for _ in range(QUEUE_SIZE + 1): + port.recv() + + +class TestChannelBlock(unittest.TestCase): + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_block(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + .multiprocessing import MultiProcessing + + mp = MultiProcessing() + mp.start() + predata = generate_data() + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + shmem_channel = Channel( + ChannelType.SHMEMCHANNEL, + QUEUE_SIZE, + nbytes, + "test_block", + "test_block", + (2, 4), + predata.dtype) + send_port = shmem_channel.src_port + recv_port = shmem_channel.dst_port + + recv_port_fn = partial(recv_proc, port=recv_port) + send_port_fn = partial(send_proc, port=send_port) + + mp.build_actor(recv_port_fn, None) + mp.build_actor(send_port_fn, None) + + time.sleep(1) + mp.stop() + mp.cleanup(True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lava/magma/runtime/message_infrastructure/test_exception.py b/tests/lava/magma/runtime/message_infrastructure/test_exception.py new file mode 100644 index 000000000..0b6727712 --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_exception.py @@ -0,0 +1,207 @@ +# Copyright (C) 2021-22 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import logging +import unittest +import numpy as np +from multiprocessing import shared_memory, Semaphore + +from lava.magma.core.decorator import implements, requires, tag +from lava.magma.core.model.py.model import PyLoihiProcessModel +from lava.magma.core.model.py.ports import PyOutPort, PyInPort +from lava.magma.core.model.py.type import LavaPyType +from lava.magma.core.process.ports.ports import OutPort, InPort +from lava.magma.core.process.process import AbstractProcess, LogConfig +from lava.magma.core.resources import CPU +from lava.magma.core.run_configs import Loihi1SimCfg +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.magma.core.run_conditions import RunSteps +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION + + +# A minimal process with an OutPort +class P1(AbstractProcess): + def __init__(self): + super().__init__(log_config=LogConfig(level=logging.CRITICAL)) + self.out = OutPort(shape=(2,)) + + +# A minimal process with an InPort +class P2(AbstractProcess): + def __init__(self): + super().__init__(log_config=LogConfig(level=logging.CRITICAL)) + self.inp = InPort(shape=(2,)) + + +# A minimal process with an InPort +class P3(AbstractProcess): + def __init__(self): + super().__init__(log_config=LogConfig(level=logging.CRITICAL)) + self.inp = InPort(shape=(2,)) + + +# A minimal PyProcModel implementing P1 +@implements(proc=P1, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyProcModel1(PyLoihiProcessModel): + out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int) + + def run_spk(self): + if self.time_step > 1: + shm = shared_memory.SharedMemory(name='error_block') + _shm_ack.acquire() + err = np.ndarray((1,), buffer=shm.buf) + err[0] += 1 + _shm_ack.release() + shm.close() + + +# A minimal PyProcModel implementing P2 +@implements(proc=P2, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyProcModel2(PyLoihiProcessModel): + inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) + + def run_spk(self): + if self.time_step > 1: + shm = shared_memory.SharedMemory(name='error_block') + _shm_ack.acquire() + err = np.ndarray((1,), buffer=shm.buf) + err[0] += 1 + _shm_ack.release() + shm.close() + + +# A minimal PyProcModel implementing P3 +@implements(proc=P3, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PyProcModel3(PyLoihiProcessModel): + inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int) + + def run_spk(self): + ... + + +class TestExceptionHandling(unittest.TestCase): + @classmethod + def setUpClass(cls): + global _shm_ack # pylint: disable=W0601 + _shm_ack = Semaphore(1) + + @classmethod + def tearDownClass(cls): + del globals()['_shm_ack'] + + def setUp(self): + """ + Creates a shared memory block. + Runs as part of unit test method. + """ + error_message = np.zeros(shape=(1,)) + shm = shared_memory.SharedMemory(create=True, + size=error_message.nbytes, + name='error_block') + err = np.ndarray(error_message.shape, dtype=np.float64, buffer=shm.buf) + err[:] = error_message[:] + shm.close() + self.shm_name = shm.name + + def tearDown(self): + """ + Destroys the shared memory block. + Runs as part of unit test method. + """ + existing_shm = shared_memory.SharedMemory(name=self.shm_name) + existing_shm.unlink() + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib version") + def test_one_pm(self): + """Checks the forwarding of exceptions within a ProcessModel to the + runtime.""" + # Create an instance of P1 + proc = P1() + + run_steps = RunSteps(num_steps=1) + run_cfg = Loihi1SimCfg( + loglevel=logging.CRITICAL) + + # Run the network for 1 time step -> no exception + proc.run(condition=run_steps, run_cfg=run_cfg) + + # Run the network for another time step -> expect exception + proc.run(condition=run_steps, run_cfg=run_cfg) + + # Check that the error count has increased to 1 + existing_shm = shared_memory.SharedMemory(name='error_block') + res = np.copy(np.frombuffer(existing_shm.buf)) + existing_shm.close() + proc.stop() + self.assertEqual(res[0], 1) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib version") + def test_two_pm(self): + """Checks the forwarding of exceptions within two ProcessModel to the + runtime.""" + # Create a sender instance of P1 and a receiver instance of P2 + sender = P1() + recv = P2() + + run_steps = RunSteps(num_steps=1) + run_cfg = Loihi1SimCfg( + loglevel=logging.CRITICAL) + + # Connect sender with receiver + sender.out.connect(recv.inp) + + # Run the network for 1 time step -> no exception + sender.run(condition=run_steps, run_cfg=run_cfg) + + # Run the network for another time step -> expect exception + sender.run(condition=run_steps, run_cfg=run_cfg) + + # Check that the error count has increased to 2 + existing_shm = shared_memory.SharedMemory(name=self.shm_name) + res = np.copy(np.frombuffer(existing_shm.buf)) + existing_shm.close() + sender.stop() + recv.stop() + self.assertEqual(res[0], 2) + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib version") + def test_three_pm(self): + """Checks the forwarding of exceptions within three ProcessModel to the + runtime.""" + # Create a sender instance of P1 and receiver instances of P2 and P3 + sender = P1() + recv1 = P2() + recv2 = P3() + + run_steps = RunSteps(num_steps=1) + run_cfg = Loihi1SimCfg( + loglevel=logging.CRITICAL) + + # Connect sender with receiver + sender.out.connect([recv1.inp, recv2.inp]) + + # Run the network for 1 time step -> no exception + sender.run(condition=run_steps, run_cfg=run_cfg) + + # Run the network for another time step -> expect exception + sender.run(condition=run_steps, run_cfg=run_cfg) + + # Check that the error count has increased to 2 + existing_shm = shared_memory.SharedMemory(name=self.shm_name) + res = np.copy(np.frombuffer(existing_shm.buf)) + existing_shm.close() + sender.stop() + recv1.stop() + recv2.stop() + self.assertEqual(res[0], 2) + + +if __name__ == '__main__': + unittest.main(buffer=False) diff --git a/tests/lava/magma/runtime/message_infrastructure/test_multiprocessing.py b/tests/lava/magma/runtime/message_infrastructure/test_multiprocessing.py new file mode 100644 index 000000000..a618d83c2 --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_multiprocessing.py @@ -0,0 +1,62 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ +import traceback +import unittest +import time +import numpy as np +from functools import partial + +from lava.magma.runtime.message_infrastructure import \ + PURE_PYTHON_VERSION + + +def nbytes_cal(shape, dtype): + return np.prod(shape) * np.dtype(dtype).itemsize + + +class Builder: + def build(self, i): + time.sleep(0.0001) + + +def target_fn(*args, **kwargs): + """ + Function to build and attach a system process to + + :param args: List Parameters to be passed onto the process + :param kwargs: Dict Parameters to be passed onto the process + :return: None + """ + try: + builder = kwargs.pop("builder") + idx = kwargs.pop("idx") + builder.build(idx) + return 0 + except Exception as e: + print("Encountered Fatal Exception: " + str(e)) + print("Traceback: ") + print(traceback.format_exc()) + raise e + + +class TestMultiprocessing(unittest.TestCase): + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib version") + def test_multiprocessing_actors(self): + from lava.magma.runtime.message_infrastructure.multiprocessing \ + import MultiProcessing + mp = MultiProcessing() + mp.start() + builder = Builder() + for i in range(5): + bound_target_fn = partial(target_fn, idx=i) + mp.build_actor(bound_target_fn, builder) + + time.sleep(0.1) + mp.stop() + mp.cleanup(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/magma/runtime/message_infrastructure/test_selector.py b/tests/lava/magma/runtime/message_infrastructure/test_selector.py new file mode 100644 index 000000000..d848a00f7 --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_selector.py @@ -0,0 +1,124 @@ +import numpy as np +import unittest +from functools import partial +from lava.magma.runtime.message_infrastructure \ + import Selector +from lava.magma.runtime.message_infrastructure import ( + PURE_PYTHON_VERSION, + Channel) + + +class Builder: + def build(self, i): + pass + + +def prepare_data(): + arr1 = np.array([1] * 9990) + arr2 = np.array([1, 2, 3, 4, 5, + 6, 7, 8, 9, 0]) + return np.concatenate((arr2, arr1)) + + +def bound_target_a1(loop, actor_to_mp_0, actor_to_mp_1, + actor_to_mp_2, builder): + to_mp_0 = actor_to_mp_0.src_port + to_mp_1 = actor_to_mp_1.src_port + to_mp_2 = actor_to_mp_2.src_port + to_mp_0.start() + to_mp_1.start() + to_mp_2.start() + predata = prepare_data() + while loop > 0: + loop = loop - 1 + to_mp_0.send(predata) + to_mp_1.send(predata) + to_mp_2.send(predata) + to_mp_0.join() + to_mp_1.join() + to_mp_2.join() + + +class TestSelector(unittest.TestCase): + + def __init__(self, methodName: str = ...) -> None: + super().__init__(methodName) + self.loop_ = 1000 + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib test") + def test_selector(self): + from lava.magma.runtime.message_infrastructure \ + .MessageInfrastructurePywrapper import ChannelType + from lava.magma.runtime.message_infrastructure \ + .multiprocessing \ + import MultiProcessing + + loop = self.loop_ * 3 + mp = MultiProcessing() + mp.start() + predata = prepare_data() + queue_size = 1 + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + selector = Selector() + actor_to_mp_0 = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "actor_to_mp_0", + "actor_to_mp_0", + (2, 2), + np.int32) + actor_to_mp_1 = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "actor_to_mp_1", + "actor_to_mp_1", + (2, 2), + np.int32) + actor_to_mp_2 = Channel( + ChannelType.SHMEMCHANNEL, + queue_size, + nbytes, + "actor_to_mp_2", + "actor_to_mp_2", + (2, 2), + np.int32) + + target_a1 = partial(bound_target_a1, self.loop_, actor_to_mp_0, + actor_to_mp_1, actor_to_mp_2) + + builder = Builder() + + mp.build_actor(target_a1, builder) # actor1 + + from_a0 = actor_to_mp_0.dst_port + from_a1 = actor_to_mp_1.dst_port + from_a2 = actor_to_mp_2.dst_port + + from_a0.start() + from_a1.start() + from_a2.start() + expect_result = predata * 3 * self.loop_ + recv_port_list = [from_a0, from_a1, from_a2] + channel_actions = [(recv_port, (lambda y: (lambda: y))( + recv_port)) for recv_port in recv_port_list] + real_result = np.array(0) + while loop > 0: + loop = loop - 1 + recv_port = selector.select(*channel_actions) + data = recv_port.recv() + real_result = real_result + data + if not np.array_equal(expect_result, real_result): + print("expect: ", expect_result) + print("result: ", real_result) + raise AssertionError() + from_a0.join() + from_a1.join() + from_a2.join() + mp.stop() + mp.cleanup(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/magma/runtime/message_infrastructure/test_temp_channel.py b/tests/lava/magma/runtime/message_infrastructure/test_temp_channel.py new file mode 100644 index 000000000..b22094990 --- /dev/null +++ b/tests/lava/magma/runtime/message_infrastructure/test_temp_channel.py @@ -0,0 +1,98 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +import unittest +from functools import partial +import time + +from lava.magma.runtime.message_infrastructure import ( + create_channel, + PURE_PYTHON_VERSION, + getTempRecvPort, + getTempSendPort +) + + +loop_number = 1000 + + +def prepare_data(): + return np.random.random_sample((65536, 10)) + + +const_data = prepare_data() + + +def actor_stop(name): + pass + + +def recv_proc(*args, **kwargs): + port = kwargs.pop("port") + port.start() + for _ in range(loop_number): + path, recv_port = getTempRecvPort() + recv_port.start() + port.send(np.array([path])) + data = recv_port.recv() + recv_port.join() + if not np.array_equal(data, const_data): + raise AssertionError() + port.join() + + +def send_proc(*args, **kwargs): + port = kwargs.pop("port") + port.start() + for _ in range(loop_number): + path = port.recv() + send_port = getTempSendPort(str(path[0])) + send_port.start() + send_port.send(const_data) + send_port.join() + port.join() + + +class Builder: + def build(self): + pass + + +class TestTempChannel(unittest.TestCase): + + @unittest.skipIf(PURE_PYTHON_VERSION, "cpp msg lib version") + def test_tempchannel(self): + from lava.magma.runtime.message_infrastructure \ + .multiprocessing import MultiProcessing + mp = MultiProcessing() + mp.start() + name = 'test_temp_channel' + + shmem_channel = create_channel( + None, + name, + name, + const_data.shape, + const_data.dtype, + const_data.size) + + send_port = shmem_channel.src_port + recv_port = shmem_channel.dst_port + + recv_port_fn = partial(recv_proc, port=send_port) + send_port_fn = partial(send_proc, port=recv_port) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + + time.sleep(0.1) + mp.stop() + mp.cleanup(True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lava/magma/runtime/test_async_protocol.py b/tests/lava/magma/runtime/test_async_protocol.py index c430282c9..98e02c857 100644 --- a/tests/lava/magma/runtime/test_async_protocol.py +++ b/tests/lava/magma/runtime/test_async_protocol.py @@ -1,7 +1,6 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: LGPL 2.1 or later # See: https://spdx.org/licenses/ - import unittest import numpy as np diff --git a/tests/lava/magma/runtime/test_context_manager.py b/tests/lava/magma/runtime/test_context_manager.py index 03d9fbc7d..ab85c115f 100644 --- a/tests/lava/magma/runtime/test_context_manager.py +++ b/tests/lava/magma/runtime/test_context_manager.py @@ -11,7 +11,8 @@ from lava.magma.core.decorator import implements, requires from lava.magma.core.model.py.model import PyLoihiProcessModel from lava.magma.core.model.py.type import LavaPyType -from lava.magma.core.process.message_interface_enum import ActorType +from lava.magma.runtime.message_infrastructure.message_interface_enum \ + import ActorType from lava.magma.core.process.process import AbstractProcess from lava.magma.core.process.variable import Var from lava.magma.core.resources import CPU diff --git a/tests/lava/magma/runtime/test_exception_handling.py b/tests/lava/magma/runtime/test_exception_handling.py index 183b162da..36b54e0fb 100644 --- a/tests/lava/magma/runtime/test_exception_handling.py +++ b/tests/lava/magma/runtime/test_exception_handling.py @@ -15,6 +15,7 @@ from lava.magma.core.run_configs import Loihi1SimCfg from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol from lava.magma.core.run_conditions import RunSteps +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION # A minimal process with an OutPort @@ -76,6 +77,8 @@ def run_spk(self): class TestExceptionHandling(unittest.TestCase): + + @unittest.skipIf(not PURE_PYTHON_VERSION, "support py version only") def test_one_pm(self): """Checks the forwarding of exceptions within a ProcessModel to the runtime.""" @@ -99,6 +102,7 @@ def test_one_pm(self): # 1 exception in the ProcessModel expected self.assertTrue('1 Exception(s) occurred' in str(exception)) + @unittest.skipIf(not PURE_PYTHON_VERSION, "support py version only") def test_two_pm(self): """Checks the forwarding of exceptions within two ProcessModel to the runtime.""" @@ -126,6 +130,7 @@ def test_two_pm(self): # 2 Exceptions in the ProcessModels expected self.assertTrue('2 Exception(s) occurred' in str(exception)) + @unittest.skipIf(not PURE_PYTHON_VERSION, "support py version only") def test_three_pm(self): """Checks the forwarding of exceptions within three ProcessModel to the runtime.""" diff --git a/tests/lava/magma/runtime/test_file_descriptors.py b/tests/lava/magma/runtime/test_file_descriptors.py new file mode 100644 index 000000000..89598cde0 --- /dev/null +++ b/tests/lava/magma/runtime/test_file_descriptors.py @@ -0,0 +1,62 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import unittest +import os +from lava.magma.core.run_conditions import RunSteps +from lava.proc.lif.process import LIF +from lava.magma.core.run_configs import Loihi1SimCfg +from time import sleep +from subprocess import run # nosec +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION + + +def run_process(): + du = 10 + dv = 100 + vth = 4900 + + # Create processes + lif2 = LIF(shape=(2, ), + vth=vth, + dv=dv, + du=du, + bias_mant=0, + name='lif2') + + lif2.run(condition=RunSteps(num_steps=1), + run_cfg=Loihi1SimCfg(select_tag="fixed_pt")) + + lif2.stop() + + +def get_file_descriptor_usage(): + result = run("lsof 2>/dev/null | grep python | grep FIFO | wc -l", # noqa: S607, E501 + shell=True) # nosec + sleep(0.1) + return result.stdout + + +class TestFileDescriptors(unittest.TestCase): + num_iterations = 1000 + + @unittest.skipIf(os.name != "posix" or PURE_PYTHON_VERSION, + "Checking file descriptor only for POSIX systems.") + def test_file_descriptor_usage(self): + # Check initial state that file descriptor usage is zero + file_descriptor_usage = get_file_descriptor_usage() + self.assertEqual(file_descriptor_usage, None) + + for iteration in range(self.num_iterations): + # Run Process + run_process() + + # Check file descriptor usage after running processes + file_descriptor_usage = get_file_descriptor_usage() + self.assertEqual(file_descriptor_usage, None, + msg=f"Failed on iteration {iteration}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lava/magma/runtime/test_get_set_non_determinism.py b/tests/lava/magma/runtime/test_get_set_non_determinism.py index a10002a20..b5fce8f56 100644 --- a/tests/lava/magma/runtime/test_get_set_non_determinism.py +++ b/tests/lava/magma/runtime/test_get_set_non_determinism.py @@ -38,6 +38,7 @@ def run_spk(self): class TestNonDeterminismUpdate(unittest.TestCase): + def test_non_determinism_update(self): nb_runs = 10000 demo_process = DemoProcess(nb_runs=nb_runs) diff --git a/tests/lava/magma/runtime/test_pause_requested_from_model.py b/tests/lava/magma/runtime/test_pause_requested_from_model.py index 68559bdbd..19e949f41 100644 --- a/tests/lava/magma/runtime/test_pause_requested_from_model.py +++ b/tests/lava/magma/runtime/test_pause_requested_from_model.py @@ -195,6 +195,7 @@ def test_stop_request_from_model_in_post_mgmt_phase(self): e = time() self.assertTrue(e - s < 100, "") self.assertFalse(process.runtime._is_running) + process.stop() @unittest.skip def test_pause_request_from_hierarchical_model(self): diff --git a/tests/lava/magma/runtime/test_ref_var_ports.py b/tests/lava/magma/runtime/test_ref_var_ports.py index 9b2d63fa0..543c6cda8 100644 --- a/tests/lava/magma/runtime/test_ref_var_ports.py +++ b/tests/lava/magma/runtime/test_ref_var_ports.py @@ -94,7 +94,7 @@ def post_guard(self): def run_post_mgmt(self): if self.time_step > 1: - ref_data = np.array([5, 5, 5]) + self.time_step + ref_data = np.array([5, 5, 5], np.int32) + self.time_step self.ref1.write(ref_data) self.ref3.write(ref_data[:2]) # ensure write() has finished before moving on diff --git a/tests/lava/magma/runtime/test_runtime.py b/tests/lava/magma/runtime/test_runtime.py index d68ed471f..13777f102 100644 --- a/tests/lava/magma/runtime/test_runtime.py +++ b/tests/lava/magma/runtime/test_runtime.py @@ -7,10 +7,12 @@ from unittest.mock import Mock from lava.magma.compiler.executable import Executable -from lava.magma.core.process.message_interface_enum import ActorType +from lava.magma.runtime.message_infrastructure.message_interface_enum \ + import ActorType from lava.magma.core.resources import HeadNode, Loihi2System from lava.magma.compiler.node import Node, NodeConfig -from lava.magma.compiler.channels.watchdog import WatchdogManagerBuilder +from lava.magma.runtime.message_infrastructure.watchdog import \ + WatchdogManagerBuilder from lava.magma.runtime.runtime import Runtime diff --git a/tests/lava/magma/runtime/test_runtime_service.py b/tests/lava/magma/runtime/test_runtime_service.py index 77dad457e..7742d2dba 100644 --- a/tests/lava/magma/runtime/test_runtime_service.py +++ b/tests/lava/magma/runtime/test_runtime_service.py @@ -3,9 +3,10 @@ # See: https://spdx.org/licenses/ import unittest - import numpy as np -from lava.magma.compiler.channels.pypychannel import PyPyChannel +from multiprocessing.managers import SharedMemoryManager +from lava.magma.runtime.message_infrastructure \ + import create_channel as create_pychannel from lava.magma.core.decorator import implements from lava.magma.core.model.py.model import AbstractPyProcessModel from lava.magma.core.process.process import AbstractProcess @@ -23,14 +24,7 @@ def __init__(self, smm): def create_channel(smm: SharedMemoryManager, name: str): mock = MockInterface(smm=smm) - return PyPyChannel( - mock, - name, - name, - (1,), - np.int32, - 8, - ) + return create_pychannel(mock, name + "src", name + "dst", (1,), np.int32, 8) class SimpleSyncProtocol(AbstractSyncProtocol): @@ -75,8 +69,6 @@ def test_runtime_service_start_run(self): service_to_runtime = create_channel(smm, name="service_to_runtime") service_to_process = [create_channel(smm, name="service_to_process")] process_to_service = [create_channel(smm, name="process_to_service")] - runtime_to_service.dst_port.start() - service_to_runtime.src_port.start() pm.service_to_process = service_to_process[0].dst_port pm.process_to_service = process_to_service[0].src_port @@ -86,9 +78,9 @@ def test_runtime_service_start_run(self): rs.service_to_runtime = service_to_runtime.dst_port rs.service_to_process = [service_to_process[0].src_port] rs.process_to_service = [process_to_service[0].dst_port] + rs.start() rs.join() pm.join() - smm.shutdown() if __name__ == '__main__': diff --git a/tests/lava/proc/dense/test_stdp_sim.py b/tests/lava/proc/dense/test_stdp_sim.py index b04542c00..d72344fb7 100644 --- a/tests/lava/proc/dense/test_stdp_sim.py +++ b/tests/lava/proc/dense/test_stdp_sim.py @@ -175,8 +175,9 @@ def run_spk(self) -> None: s_out_y1: sends the post-synaptic spike times. s_out_y2: sends the graded third-factor reward signal. """ - - self.y1 = self.compute_post_synaptic_trace(self.s_out_buff) + # pylint: disable=W0201 + self.y1 = \ + self.compute_post_synaptic_trace(self.s_out_buff).astype(np.int32) super().run_spk() diff --git a/tests/lava/proc/io/test_dataloader.py b/tests/lava/proc/io/test_dataloader.py index be53492c7..6e72fae47 100644 --- a/tests/lava/proc/io/test_dataloader.py +++ b/tests/lava/proc/io/test_dataloader.py @@ -43,8 +43,9 @@ def select( class DummyDataset: - def __init__(self, shape: tuple) -> None: + def __init__(self, shape: tuple, dtype: np.dtype) -> None: self.shape = shape + self.dtype = dtype def __len__(self) -> int: return 10 @@ -52,6 +53,7 @@ def __len__(self) -> int: def __getitem__(self, id_: int) -> Tuple[np.ndarray, int]: data = np.arange(np.prod(self.shape)).reshape(self.shape) + id_ data = data % np.prod(self.shape) + data = data.astype(self.dtype) label = id_ return data, label @@ -90,9 +92,10 @@ def test_state_loader(self) -> None: shape = (5, 7) interval = 5 offset = 2 + dtype = np.int32 proc = DummyProc(shape) - dataloader = StateDataloader(dataset=DummyDataset(shape), + dataloader = StateDataloader(dataset=DummyDataset(shape, dtype), interval=interval, offset=offset) @@ -112,7 +115,7 @@ def test_state_loader(self) -> None: out_data = out.data.get() proc.stop() - dataset = DummyDataset(shape) + dataset = DummyDataset(shape, np.int32) for i in range(offset + 1, num_steps): id = (i - offset - 1) // interval data, ground_truth = dataset[id] @@ -136,7 +139,7 @@ def run_test( num_steps: int, ) -> None: dataloader = SpikeDataloader( - dataset=SpikeDataset(shape + (steps,)), + dataset=SpikeDataset(shape + (steps,), np.int32), interval=interval, offset=offset ) @@ -156,7 +159,7 @@ def run_test( out_data = out.data.get() dataloader.stop() - dataset = SpikeDataset(shape + (steps,)) + dataset = SpikeDataset(shape + (steps,), np.int32) for i in range(offset + 1, num_steps, interval): id = (i - offset - 1) // interval data, ground_truth = dataset[id] @@ -215,4 +218,4 @@ def test_spike_loader_more_steps(self) -> None: if __name__ == '__main__': - pass + unittest.main() diff --git a/tests/lava/proc/io/test_extractor.py b/tests/lava/proc/io/test_extractor.py index 141b0c251..20fee145d 100644 --- a/tests/lava/proc/io/test_extractor.py +++ b/tests/lava/proc/io/test_extractor.py @@ -20,12 +20,21 @@ from lava.magma.core.model.py.ports import PyOutPort from lava.magma.core.run_configs import Loihi2SimCfg from lava.magma.core.run_conditions import RunSteps, RunContinuous -from lava.magma.runtime.message_infrastructure.multiprocessing import \ - MultiProcessing -from lava.magma.compiler.channels.pypychannel import PyPyChannel, CspSendPort from lava.proc.io.extractor import Extractor, PyLoihiExtractorModel from lava.proc.io import utils +from lava.magma.runtime.message_infrastructure import SendPort as CspSendPort +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION +if PURE_PYTHON_VERSION: + from lava.magma.runtime.message_infrastructure.py_multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure.pypychannel \ + import PyPyChannel +else: + from lava.magma.runtime.message_infrastructure.multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure import Channel as PyPyChannel + class Send(AbstractProcess): """Process that sends arbitrary dense data stored in a ring buffer. @@ -56,6 +65,7 @@ def run_spk(self) -> None: self.out_port.send(data) +@unittest.skipUnless(PURE_PYTHON_VERSION, "cppbackend to be fixed") class TestExtractor(unittest.TestCase): def test_init(self): """Test that the Extractor Process is instantiated correctly.""" @@ -136,6 +146,7 @@ def test_invalid_channel_config(self): Extractor(shape=out_shape, channel_config=channel_config) +@unittest.skipUnless(PURE_PYTHON_VERSION, "cppbackend to be fixed") class TestPyLoihiExtractorModel(unittest.TestCase): def test_init(self): """Test that the PyLoihiExtractorModel ProcessModel is instantiated @@ -144,7 +155,7 @@ def test_init(self): buffer_size = 10 multi_processing = MultiProcessing() - multi_processing.start() + multi_processing.init() channel = PyPyChannel(message_infrastructure=multi_processing, src_name="src", dst_name="dst", diff --git a/tests/lava/proc/io/test_injector.py b/tests/lava/proc/io/test_injector.py index a200d705d..e55a8110b 100644 --- a/tests/lava/proc/io/test_injector.py +++ b/tests/lava/proc/io/test_injector.py @@ -20,9 +20,17 @@ from lava.magma.core.model.py.ports import PyInPort from lava.magma.core.run_configs import Loihi2SimCfg from lava.magma.core.run_conditions import RunSteps, RunContinuous -from lava.magma.runtime.message_infrastructure.multiprocessing import \ - MultiProcessing -from lava.magma.compiler.channels.pypychannel import PyPyChannel, CspRecvPort +from lava.magma.runtime.message_infrastructure import PURE_PYTHON_VERSION +if PURE_PYTHON_VERSION: + from lava.magma.runtime.message_infrastructure.py_multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure.pypychannel \ + import PyPyChannel +else: + from lava.magma.runtime.message_infrastructure.multiprocessing \ + import MultiProcessing + from lava.magma.runtime.message_infrastructure import Channel as PyPyChannel +from lava.magma.runtime.message_infrastructure import RecvPort as CspRecvPort from lava.proc.io.injector import Injector, PyLoihiInjectorModel from lava.proc.io import utils @@ -63,6 +71,7 @@ def run_spk(self) -> None: (self.time_step - 1) % self._buffer_size] = self.in_port.recv() +@unittest.skipUnless(PURE_PYTHON_VERSION, "cppbackend to be fixed") class TestInjector(unittest.TestCase): def test_init(self): """Test that the Injector Process is instantiated correctly.""" @@ -144,6 +153,7 @@ def test_invalid_channel_config(self): Injector(shape=out_shape, channel_config=channel_config) +@unittest.skipUnless(PURE_PYTHON_VERSION, "cppbackend to be fixed") class TestPyLoihiInjectorModel(unittest.TestCase): def test_init(self): """Test that the PyLoihiInjectorModel ProcessModel is instantiated @@ -152,7 +162,7 @@ def test_init(self): buffer_size = 10 multi_processing = MultiProcessing() - multi_processing.start() + multi_processing.init() channel = PyPyChannel(message_infrastructure=multi_processing, src_name="src", dst_name="dst", @@ -536,3 +546,7 @@ def test_run_continuous(self): # portion. np.testing.assert_equal(recv_var_data[:num_send // 10], send_data[:num_send // 10]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lava/proc/monitor/test_monitors.py b/tests/lava/proc/monitor/test_monitors.py index 17844181d..341140d3f 100644 --- a/tests/lava/proc/monitor/test_monitors.py +++ b/tests/lava/proc/monitor/test_monitors.py @@ -45,9 +45,9 @@ def post_guard(self): def run_post_mgmt(self): if self.time_step > 1: - self.s = np.array([self.time_step]) - self.u = 2 * np.array([self.time_step]) - self.v = np.array([[1, 2], [3, 4]]) + self.s = np.array([self.time_step], dtype=np.int32) + self.u = 2 * np.array([self.time_step], dtype=np.int32) + self.v = np.array([[1, 2], [3, 4]], dtype=np.int32) class Monitors(unittest.TestCase): @@ -180,13 +180,13 @@ def test_monitor_collects_correct_data_from_2D_var(self): # Access the collected data with the names of monitor proc and var probe_data = data[some_proc.name][some_proc.v.name] + # Stop running + some_proc.stop() + # Check if the collected data match the expected data self.assertTrue(np.all(probe_data == np.tile(np.array([[1, 2], [3, 4]]), (num_steps, 1, 1)))) - # Stop running - some_proc.stop() - def test_monitor_collects_voltage_and_spike_data_from_lif_neuron(self): """Check if two different Monitor process can monitor voltage (Var) and s_out (OutPort) of a LIF neuron. Check the collected data with diff --git a/tests/lava/proc/sdn/test_models.py b/tests/lava/proc/sdn/test_models.py index 76572c531..be86f1ea1 100644 --- a/tests/lava/proc/sdn/test_models.py +++ b/tests/lava/proc/sdn/test_models.py @@ -260,3 +260,7 @@ def test_reconstruction_relu_float(self) -> None: if verbose: print(f'Max abs error = {error}') self.assertTrue(error < vth * (1 << spike_exp)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/proc/sparse/test_models.py b/tests/lava/proc/sparse/test_models.py index a864459c8..9c9eaad3f 100644 --- a/tests/lava/proc/sparse/test_models.py +++ b/tests/lava/proc/sparse/test_models.py @@ -301,7 +301,6 @@ def test_weights_get(self): run_cfg = Loihi2SimCfg(select_tag='floating_pt') conn = Sparse(weights=weights_sparse) - sparse_net = create_network(inp, conn, weights_sparse) conn.run(condition=run_cond, run_cfg=run_cfg) weights_got = conn.weights.get() @@ -328,7 +327,6 @@ def test_weights_set(self): run_cfg = Loihi2SimCfg(select_tag='floating_pt') conn = Sparse(weights=weights_init_sparse) - sparse_net = create_network(inp, conn, weights_init_sparse) conn.run(condition=run_cond, run_cfg=run_cfg) new_weights_sparse = conn.weights.init.copy() @@ -451,7 +449,6 @@ def test_consistency_with_learning_dense_random_shape_dt(self): tag_1=weights.copy(), tag_2=weights.copy(), learning_rule=learning_rule) - dense_net = create_learning_network(pre, conn, post) run_cond = RunSteps(num_steps=simtime) run_cfg = Loihi2SimCfg(select_tag='floating_pt') @@ -469,7 +466,6 @@ def test_consistency_with_learning_dense_random_shape_dt(self): tag_1=weights_sparse.copy(), tag_2=weights_sparse.copy(), learning_rule=learning_rule) - sparse_net = create_learning_network(pre, conn, post) conn.run(condition=run_cond, run_cfg=run_cfg) tags_got_sparse = conn.tag_1.get() @@ -1450,3 +1446,7 @@ def test_bitacc_pm_recurrence_delay(self): rcfg = Loihi2SimCfg(select_tag='floating_pt') sparse.run(condition=rcnd, run_cfg=rcfg) sparse.stop() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lava/tutorials/test_tutorials.py b/tests/lava/tutorials/test_tutorials.py index 3f752c081..ab1fa2a10 100644 --- a/tests/lava/tutorials/test_tutorials.py +++ b/tests/lava/tutorials/test_tutorials.py @@ -25,7 +25,7 @@ class TestTutorials(unittest.TestCase): def _execute_notebook( self, base_dir: str, path: str - ) -> ty.Tuple[ty.Type[nbformat.NotebookNode], ty.List[str]]: + ) -> int: """Execute a notebook via nbconvert and collect output. Parameters @@ -37,23 +37,22 @@ def _execute_notebook( Returns ------- - Tuple - (parsed nbformat.NotebookNode object, list of execution errors) + int + (return code) """ cwd = os.getcwd() dir_name, notebook = os.path.split(path) try: env = self._update_pythonpath(base_dir, dir_name) - nb = self._convert_and_execute_notebook(notebook, env) - errors = self._collect_errors_from_all_cells(nb) + result = self._convert_and_execute_notebook(notebook, env) + errors = self._collect_errors_from_all_cells(result) except Exception as e: - nb = None - errors = str(e) + errors = -1 finally: os.chdir(cwd) - return nb, errors + return errors def _update_pythonpath( self, base_dir: str, dir_name: str @@ -91,7 +90,7 @@ def _update_pythonpath( def _convert_and_execute_notebook( self, notebook: str, env: ty.Dict[str, str] - ) -> ty.Type[nbformat.NotebookNode]: + ): """Covert notebook and execute it. Parameters @@ -106,25 +105,24 @@ def _convert_and_execute_notebook( nb : nbformat.NotebookNode Notebook dict-like node with attribute-access """ - with tempfile.NamedTemporaryFile(mode="w+t", suffix=".ipynb") as fout: + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py") as fout: args = [ "jupyter", "nbconvert", "--to", - "notebook", - "--execute", - "--ExecutePreprocessor.timeout=-1", + "python", "--output", - fout.name, + fout.name[0:-3], notebook, ] subprocess.check_call(args, env=env) # nosec # noqa: S603 fout.seek(0) - return nbformat.read(fout, nbformat.current_nbformat) + return subprocess.run(["ipython", "-c", fout.read()], # noqa # nosec + env=env) # noqa # nosec def _collect_errors_from_all_cells( - self, nb: nbformat.NotebookNode + self, result ) -> ty.List[str]: """Collect errors from executed notebook. @@ -138,13 +136,9 @@ def _collect_errors_from_all_cells( List Collection of errors """ - errors = [] - for cell in nb.cells: - if "outputs" in cell: - for output in cell["outputs"]: - if output.output_type == "error": - errors.append(output) - return errors + if result.returncode != 0: + result.check_returncode() + return result.returncode def _run_notebook(self, notebook: str, e2e_tutorial: bool = False): """Run a specific notebook @@ -183,22 +177,10 @@ def _run_notebook(self, notebook: str, e2e_tutorial: bool = False): # If the notebook is found execute it and store any errors for notebook_name in discovered_notebooks: - nb, errors = self._execute_notebook( + errors = self._execute_notebook( str(tutorials_directory), notebook_name ) - errors_joined = ( - "\n".join(errors) if isinstance(errors, list) else errors - ) - if errors: - errors_record[notebook_name] = (errors_joined, nb) - - self.assertFalse( - errors_record, - "Failed to execute Jupyter Notebooks \ - with errors: \n {}".format( - errors_record - ), - ) + self.assertEqual(errors, 0) finally: os.chdir(cwd) diff --git a/tutorials/in_depth/tutorial03_process_models.ipynb b/tutorials/in_depth/tutorial03_process_models.ipynb index 198cb3562..9bfd48781 100644 --- a/tutorials/in_depth/tutorial03_process_models.ipynb +++ b/tutorials/in_depth/tutorial03_process_models.ipynb @@ -254,7 +254,8 @@ "\n", "run_cfg = Loihi1SimCfg()\n", "lif.run(condition=RunSteps(num_steps=10), run_cfg=run_cfg)\n", - "print(lif.v.get())" + "print(lif.v.get())\n", + "lif.stop()" ] }, { diff --git a/tutorials/in_depth/tutorial04_execution.ipynb b/tutorials/in_depth/tutorial04_execution.ipynb index 35544e1ac..8cc62da78 100644 --- a/tutorials/in_depth/tutorial04_execution.ipynb +++ b/tutorials/in_depth/tutorial04_execution.ipynb @@ -125,7 +125,8 @@ "lif = LIF(shape=(1,))\n", "\n", "# execute that Process for 42 time steps in simulation\n", - "lif.run(condition=RunSteps(num_steps=42), run_cfg=Loihi1SimCfg())" + "lif.run(condition=RunSteps(num_steps=42), run_cfg=Loihi1SimCfg())\n", + "lif.stop()" ] }, { @@ -160,7 +161,8 @@ "dense.a_out.connect(lif2.a_in)\n", "\n", "# execute Process lif2 and all Processes connected to it (dense, lif1)\n", - "lif2.run(condition=RunSteps(num_steps=42), run_cfg=Loihi1SimCfg())" + "lif2.run(condition=RunSteps(num_steps=42), run_cfg=Loihi1SimCfg())\n", + "lif2.stop()" ] }, { @@ -348,7 +350,7 @@ "source": [ "from lava.magma.runtime.runtime import Runtime\n", "from lava.magma.core.run_conditions import RunSteps\n", - "from lava.magma.core.process.message_interface_enum import ActorType\n", + "from lava.magma.runtime.message_infrastructure.message_interface_enum import ActorType\n", "\n", "# create and initialize a runtime\n", "mp = ActorType.MultiProcessing\n",