Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING

from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501
TreeSitterSegmenter,
)

if TYPE_CHECKING:
from tree_sitter import Language


# Query string handed back by FortranSegmenter.get_chunk_query(). It is an
# alternation over four node patterns -- subroutine, function, program and
# module -- each tagged with a capture of the same name. Surrounding
# newlines are removed by the trailing .strip().
CHUNK_QUERY = """
[
(subroutine) @subroutine
(function) @function
(program) @program
(module) @module
]
""".strip()


class FortranSegmenter(TreeSitterSegmenter):
    """Segmenter for Fortran source built on ``TreeSitterSegmenter``.

    Supplies the three Fortran-specific pieces: the grammar, the chunk
    query, and the line-comment formatting.
    """

    def get_language(self) -> "Language":
        """Return the grammar registered under the name ``"fortran"``."""
        # Imported inside the method, so ``tree_sitter_languages`` is only
        # needed when this method is actually called.
        import tree_sitter_languages

        return tree_sitter_languages.get_language("fortran")

    def get_chunk_query(self) -> str:
        """Return the module-level ``CHUNK_QUERY`` string unchanged."""
        query: str = CHUNK_QUERY
        return query

    def make_line_comment(self, text: str) -> str:
        """Return *text* prefixed with ``"! "`` as a one-line comment."""
        return "! {}".format(text)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from langchain_community.document_loaders.parsers.language.cpp import CPPSegmenter
from langchain_community.document_loaders.parsers.language.csharp import CSharpSegmenter
from langchain_community.document_loaders.parsers.language.elixir import ElixirSegmenter
from langchain_community.document_loaders.parsers.language.fortran import (
FortranSegmenter,
)
from langchain_community.document_loaders.parsers.language.go import GoSegmenter
from langchain_community.document_loaders.parsers.language.java import JavaSegmenter
from langchain_community.document_loaders.parsers.language.javascript import (
Expand Down Expand Up @@ -49,6 +52,12 @@
"ex": "elixir",
"exs": "elixir",
"sql": "sql",
"f": "fortran",
"f90": "fortran",
"f95": "fortran",
"f03": "fortran",
"f08": "fortran",
"for": "fortran",
}

LANGUAGE_SEGMENTERS: Dict[str, Any] = {
Expand All @@ -70,6 +79,7 @@
"php": PHPSegmenter,
"elixir": ElixirSegmenter,
"sql": SQLSegmenter,
"fortran": FortranSegmenter,
}

Language = Literal[
Expand Down Expand Up @@ -97,6 +107,7 @@
"perl",
"elixir",
"sql",
"fortran",
]


Expand All @@ -116,6 +127,7 @@ class LanguageParser(BaseBlobParser):
- C#: "csharp" (*)
- COBOL: "cobol"
- Elixir: "elixir"
- Fortran: "fortran" (*)
- Go: "go" (*)
- Java: "java" (*)
- JavaScript: "js" (requires package `esprima`)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import unittest

import pytest

from langchain_community.document_loaders.parsers.language.fortran import (
FortranSegmenter,
)


# NOTE(review): the ``requires`` marker presumably skips these tests when the
# listed packages are not installed -- confirm against the project's conftest.
@pytest.mark.requires("tree_sitter", "tree_sitter_languages")
class TestFortranSegmenter(unittest.TestCase):
    """Unit tests for ``FortranSegmenter``.

    Covers ``is_valid``, ``extract_functions_classes`` and ``simplify_code``
    against a single Fortran sample built in ``setUp``.
    """

    def setUp(self) -> None:
        """Build the Fortran sample and the outputs expected from it."""
        # Sample with four top-level constructs, in order: a program, a
        # subroutine, a function, and a module that itself contains a
        # function (``multiply``). The text ends without a trailing newline.
        self.example_code = """program hello
implicit none
write(*,*) 'Hello, World!'
end program hello

subroutine greet(name)
implicit none
character(len=*), intent(in) :: name
write(*,*) 'Hello, ', name
end subroutine greet

function add(a, b) result(c)
implicit none
integer, intent(in) :: a, b
integer :: c
c = a + b
end function add

module math_ops
implicit none
contains
function multiply(x, y) result(z)
real :: x, y, z
z = x * y
end function multiply
end module math_ops"""

        # Expected simplification: one "! Code for: <first line>" comment per
        # top-level construct. The nested ``multiply`` function gets no line
        # of its own.
        self.expected_simplified_code = """! Code for: program hello
! Code for: subroutine greet(name)
! Code for: function add(a, b) result(c)
! Code for: module math_ops"""

        # Expected chunks, one per top-level construct and in source order.
        # The first three end with "\n"; the last does not, matching the
        # sample, which ends right after "end module math_ops".
        self.expected_extracted_code = [
            (
                "program hello\n implicit none\n write(*,*) "
                "'Hello, World!'\nend program hello\n"
            ),
            (
                "subroutine greet(name)\n implicit none\n "
                "character(len=*), intent(in) :: name\n write(*,*) "
                "'Hello, ', name\nend subroutine greet\n"
            ),
            (
                "function add(a, b) result(c)\n implicit none\n "
                "integer, intent(in) :: a, b\n integer :: c\n "
                "c = a + b\nend function add\n"
            ),
            (
                "module math_ops\n implicit none\ncontains\n "
                "function multiply(x, y) result(z)\n real :: x, y, z\n"
                " z = x * y\n end function multiply\n"
                "end module math_ops"
            ),
        ]

    def test_is_valid(self) -> None:
        """A minimal program is valid; a run of bare tokens is not."""
        self.assertTrue(FortranSegmenter("program test\nend program test").is_valid())
        self.assertFalse(FortranSegmenter("a b c 1 2 3").is_valid())

    def test_extract_functions_classes(self) -> None:
        """Extraction yields exactly the expected list of chunks."""
        segmenter = FortranSegmenter(self.example_code)
        extracted_code = segmenter.extract_functions_classes()
        self.assertEqual(extracted_code, self.expected_extracted_code)

    def test_simplify_code(self) -> None:
        """Simplification yields exactly the expected comment lines."""
        segmenter = FortranSegmenter(self.example_code)
        simplified_code = segmenter.simplify_code()
        self.assertEqual(simplified_code, self.expected_simplified_code)