diff --git a/libs/community/langchain_community/document_loaders/parsers/language/fortran.py b/libs/community/langchain_community/document_loaders/parsers/language/fortran.py new file mode 100644 index 000000000..503bb34c3 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/parsers/language/fortran.py @@ -0,0 +1,33 @@ +from typing import TYPE_CHECKING + +from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501 + TreeSitterSegmenter, +) + +if TYPE_CHECKING: + from tree_sitter import Language + + +CHUNK_QUERY = """ + [ + (subroutine) @subroutine + (function) @function + (program) @program + (module) @module + ] +""".strip() + + +class FortranSegmenter(TreeSitterSegmenter): + """Code segmenter for Fortran.""" + + def get_language(self) -> "Language": + from tree_sitter_languages import get_language + + return get_language("fortran") + + def get_chunk_query(self) -> str: + return CHUNK_QUERY + + def make_line_comment(self, text: str) -> str: + return f"! {text}" diff --git a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py index e1d4e5ec6..0844fd2d4 100644 --- a/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py +++ b/libs/community/langchain_community/document_loaders/parsers/language/language_parser.py @@ -11,6 +11,9 @@ from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter +from langchain_community.document_loaders.parsers.language.fortran import ( + FortranSegmenter, +) from langchain_community.document_loaders.parsers.language.go import GoSegmenter from langchain_community.document_loaders.parsers.language.java import JavaSegmenter from langchain_community.document_loaders.parsers.language.javascript import ( @@ -49,6 +52,12 @@ "ex": "elixir", "exs": "elixir", "sql": "sql", + "f": "fortran", + "f90": "fortran", + "f95": "fortran", + "f03": "fortran", + "f08": "fortran", + "for": "fortran", } LANGUAGE_SEGMENTERS: Dict[str, Any] = { @@ -70,6 +79,7 @@ "php": PHPSegmenter, "elixir": ElixirSegmenter, "sql": SQLSegmenter, + "fortran": FortranSegmenter, } Language = Literal[ @@ -97,6 +107,7 @@ "perl", "elixir", "sql", + "fortran", ] @@ -116,6 +127,7 @@ class LanguageParser(BaseBlobParser): - C#: "csharp" (*) - COBOL: "cobol" - Elixir: "elixir" + - Fortran: "fortran" (*) - Go: "go" (*) - Java: "java" (*) - JavaScript: "js" (requires package `esprima`) diff --git a/libs/community/tests/unit_tests/document_loaders/parsers/language/test_fortran.py b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_fortran.py new file mode 100644 index 000000000..a180e9db0 --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/parsers/language/test_fortran.py @@ -0,0 +1,80 @@ +import unittest + +import pytest + +from langchain_community.document_loaders.parsers.language.fortran import ( + FortranSegmenter, +) + + +@pytest.mark.requires("tree_sitter", "tree_sitter_languages") +class TestFortranSegmenter(unittest.TestCase): + def setUp(self) -> None: + self.example_code = """program hello + implicit none + write(*,*) 'Hello, World!' +end program hello + +subroutine greet(name) + implicit none + character(len=*), intent(in) :: name + write(*,*) 'Hello, ', name +end subroutine greet + +function add(a, b) result(c) + implicit none + integer, intent(in) :: a, b + integer :: c + c = a + b +end function add + +module math_ops + implicit none +contains + function multiply(x, y) result(z) + real :: x, y, z + z = x * y + end function multiply +end module math_ops""" + + self.expected_simplified_code = """! Code for: program hello +! Code for: subroutine greet(name) +! Code for: function add(a, b) result(c) +! Code for: module math_ops""" + + self.expected_extracted_code = [ + ( + "program hello\n implicit none\n write(*,*) " + "'Hello, World!'\nend program hello\n" + ), + ( + "subroutine greet(name)\n implicit none\n " + "character(len=*), intent(in) :: name\n write(*,*) " + "'Hello, ', name\nend subroutine greet\n" + ), + ( + "function add(a, b) result(c)\n implicit none\n " + "integer, intent(in) :: a, b\n integer :: c\n " + "c = a + b\nend function add\n" + ), + ( + "module math_ops\n implicit none\ncontains\n " + "function multiply(x, y) result(z)\n real :: x, y, z\n" + " z = x * y\n end function multiply\n" + "end module math_ops" + ), + ] + + def test_is_valid(self) -> None: + self.assertTrue(FortranSegmenter("program test\nend program test").is_valid()) + self.assertFalse(FortranSegmenter("a b c 1 2 3").is_valid()) + + def test_extract_functions_classes(self) -> None: + segmenter = FortranSegmenter(self.example_code) + extracted_code = segmenter.extract_functions_classes() + self.assertEqual(extracted_code, self.expected_extracted_code) + + def test_simplify_code(self) -> None: + segmenter = FortranSegmenter(self.example_code) + simplified_code = segmenter.simplify_code() + self.assertEqual(simplified_code, self.expected_simplified_code)