Skip to content

Commit d736542

Browse files
authored
add share parameter to webmodule and add callable condition for AdaptiveTransform [skip ci] (#564)
1 parent 3e476f3 commit d736542

File tree

4 files changed

+34
-21
lines changed

4 files changed

+34
-21
lines changed

lazyllm/tools/rag/doc_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _create_nodes_impl(self, p_nodes, group_name):
6565
# NOTE transform.batch_forward will set children for p_nodes, but when calling
6666
# transform.batch_forward, p_nodes has been upsert in the store.
6767
t = self._node_groups[group_name]['transform']
68-
transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t)
68+
transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t, group_name)
6969
nodes = transform.batch_forward(p_nodes, group_name)
7070
self._store.update_nodes(nodes)
7171
return nodes
@@ -118,7 +118,7 @@ def _reparse_group_recursive(self, p_nodes: List[DocNode], cur_name: str, doc_id
118118
raise Exception(f"Failed to remove nodes for docs {doc_ids} group {cur_name} from store")
119119

120120
t = self._node_groups[cur_name]['transform']
121-
transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t)
121+
transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t, cur_name)
122122
nodes = transform.batch_forward(p_nodes, cur_name)
123123
# reparse need set global_metadata
124124
self._store.update_nodes(nodes)

lazyllm/tools/rag/transform.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from copy import copy as lite_copy
23
from dataclasses import dataclass, field
34
import requests
45
import os
@@ -22,7 +23,7 @@ class TransformArgs():
2223
trans_node: Optional[bool] = None
2324
num_workers: int = 0
2425
kwargs: Dict = field(default_factory=dict)
25-
pattern: Optional[str] = None
26+
pattern: Optional[Union[str, Callable[[str], bool]]] = None
2627

2728
@staticmethod
2829
def from_dict(d):
@@ -73,6 +74,7 @@ def split_text_keep_separator(text: str, separator: str) -> List[str]:
7374
class NodeTransform(ABC):
7475
def __init__(self, num_workers: int = 0):
7576
self._number_workers = num_workers
77+
self._name = None
7678

7779
def batch_forward(
7880
self, documents: Union[DocNode, List[DocNode]], node_group: str, **kwargs
@@ -100,36 +102,43 @@ def impl(node: DocNode):
100102
def transform(self, document: DocNode, **kwargs) -> List[Union[str, DocNode]]:
101103
raise NotImplementedError('Not implemented')
102104

105+
def with_name(self, name: Optional[str], *, copy: bool = True) -> 'NodeTransform':
106+
if name is not None:
107+
if copy: return lite_copy(self).with_name(name, copy=False)
108+
self._name = name
109+
return self
110+
103111
def __call__(self, node: DocNode, **kwargs: Any) -> List[DocNode]:
104112
# Parent and child should not be set here.
105113
results = self.transform(node, **kwargs)
106114
if isinstance(results, (DocNode, str)): results = [results]
107115
return [DocNode(text=chunk) if isinstance(chunk, str) else chunk for chunk in results if chunk]
108116

109117

110-
def make_transform(t):
118+
def make_transform(t: Union[TransformArgs, Dict[str, Any]], group_name: Optional[str] = None) -> NodeTransform:
111119
if isinstance(t, dict): t = TransformArgs.from_dict(t)
112120
transform, trans_node, num_workers = t['f'], t['trans_node'], t['num_workers']
113121
num_workers = dict(num_workers=num_workers) if num_workers > 0 else dict()
114-
return (transform(**t['kwargs'], **num_workers)
115-
if isinstance(transform, type)
116-
else transform if isinstance(transform, NodeTransform)
117-
else FuncNodeTransform(transform, trans_node=trans_node, **num_workers))
122+
return (transform(**t['kwargs'], **num_workers).with_name(group_name, copy=False) if isinstance(transform, type)
123+
else transform.with_name(group_name) if isinstance(transform, NodeTransform)
124+
else FuncNodeTransform(transform, trans_node=trans_node, **num_workers).with_name(group_name, copy=False))
118125

119126

120127
class AdaptiveTransform(NodeTransform):
121-
def __init__(self, transforms: Union[List[TransformArgs], TransformArgs]):
122-
super().__init__(num_workers=0)
128+
def __init__(self, transforms: Union[List[Union[TransformArgs, Dict]], Union[TransformArgs, Dict]],
129+
num_workers: int = 0):
130+
super().__init__(num_workers=num_workers)
123131
if not isinstance(transforms, (tuple, list)): transforms = [transforms]
124132
self._transformers = [(t.get('pattern'), make_transform(t)) for t in transforms]
125133

126134
def transform(self, document: DocNode, **kwargs) -> List[Union[str, DocNode]]:
135+
if not isinstance(document, DocNode): LOG.warning(f'Invalud document type {type(document)} got')
127136
for pt, transform in self._transformers:
128-
if pt and not pt.startswith('*'): pt = os.path.join(str(os.cwd()), pt)
129-
if not isinstance(document, DocNode):
130-
LOG.warning(f'Invalud document type {type(document)} got')
131-
if not pt or fnmatch.fnmatch(document.docpath, pt):
137+
if pt and isinstance(pt, str) and not pt.startswith('*'): pt = os.path.join(str(os.cwd()), pt)
138+
if not pt or (callable(pt) and pt(document.docpath)) or (
139+
isinstance(pt, str) and fnmatch.fnmatch(document.docpath, pt)):
132140
return transform(document, **kwargs)
141+
LOG.warning(f'No transform found for document {document.docpath} with group name `{self._name}`')
133142
return []
134143

135144

lazyllm/tools/webpages/webmodule.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import time
1111
import re
1212
from pathlib import Path
13-
from typing import List, Union
13+
from typing import List, Union, Optional, Any, Dict
1414

1515
import lazyllm
1616
from lazyllm import LOG, globals, FileSystemQueue, OnlineChatModule, TrainableModule
@@ -34,10 +34,12 @@ class Mode:
3434
Refresh = 1
3535
Appendix = 2
3636

37-
def __init__(self, m, *, components=dict(), title='对话演示终端', port=None,
38-
history=[], text_mode=None, trace_mode=None, audio=False, stream=False,
39-
files_target=None, static_paths: Union[str, Path, List[str | Path]] = None,
40-
encode_files=False) -> None:
37+
def __init__(self, m: Any, *, components: Dict[Any, Any] = dict(), title: str = '对话演示终端',
38+
port: Optional[Union[int, range, tuple, list]] = None, history: List[Any] = [],
39+
text_mode: Optional[Mode] = None, trace_mode: Optional[Mode] = None, audio: bool = False,
40+
stream: bool = False, files_target: Optional[Union[Any, List[Any]]] = None,
41+
static_paths: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
42+
encode_files: bool = False, share: bool = False) -> None:
4143
super().__init__()
4244
# Set the static directory of gradio so that gradio can access local resources in the directory
4345
if isinstance(static_paths, (str, Path)):
@@ -66,6 +68,7 @@ def __init__(self, m, *, components=dict(), title='对话演示终端', port=Non
6668
self.stream = stream
6769
self.files_target = files_target if isinstance(files_target, list) or files_target is None else [files_target]
6870
self.encode_files = encode_files
71+
self.share = share
6972
self.demo = self.init_web(components)
7073
self.url = None
7174
signal.signal(signal.SIGINT, self._signal_handler)
@@ -405,7 +408,7 @@ def _work(self):
405408
self.url = f'http://127.0.0.1:{port}'
406409
self.broadcast_url = f'http://0.0.0.0:{port}'
407410

408-
self.demo.queue().launch(server_name="0.0.0.0", server_port=port, prevent_thread_lock=True)
411+
self.demo.queue().launch(server_name="0.0.0.0", server_port=port, prevent_thread_lock=True, share=self.share)
409412
LOG.success('LazyLLM webmodule launched successfully: Running on: '
410413
f'{self.broadcast_url}, local URL: {self.url}')
411414

tests/basic_tests/test_document.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def test_register_with_pattern(self):
132132
TransformArgs(f=SentenceSplitter, pattern='*.txt', kwargs=dict(chunk_size=512, chunk_overlap=50)),
133133
dict(f=SentenceSplitter, kwargs=dict(chunk_size=256, chunk_overlap=25))])
134134
Document.create_node_group('AdaptiveChunk2', transform=AdaptiveTransform([
135-
dict(f=SentenceSplitter, pattern='*.txt', kwargs=dict(chunk_size=512, chunk_overlap=50)),
135+
dict(f=SentenceSplitter, pattern=(lambda x: x.endswith('.txt')),
136+
kwargs=dict(chunk_size=512, chunk_overlap=50)),
136137
TransformArgs(f=SentenceSplitter, pattern=None, kwargs=dict(chunk_size=256, chunk_overlap=25))]))
137138
doc = Document('rag_master')
138139
doc._impl._lazy_init()

0 commit comments

Comments
 (0)