-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchunks.json
More file actions
4460 lines (4460 loc) · 414 KB
/
chunks.json
File metadata and controls
4460 lines (4460 loc) · 414 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
[
{
"id": "build.py::2",
"input_type": "file",
"content": "import os\nimport subprocess\nimport shutil\nimport toml\nimport sys\n\ndef run_command(command, cwd=None):\n result = subprocess.run(command, shell=True, cwd=cwd, check=True)\n return result\n\ndef npm_install():\n print(\"Running npm install\")\n run_command(\"npm install\", cwd=\"ell-studio\")\n\n\ndef npm_build():\n print(\"Running npm build\")\n run_command(\"npm run build\", cwd=\"ell-studio\")\n print(\"Copying static files\")\n source_dir = os.path.join(\"ell-studio\", \"build\")\n target_dir = os.path.join(\"src\", \"ell\", \"studio\", \"static\")\n shutil.rmtree(target_dir, ignore_errors=True)\n shutil.copytree(source_dir, target_dir)\n print(f\"Copied static files from {source_dir} to {target_dir}\")\n\n\ndef get_ell_version():\n pyproject_path = \"pyproject.toml\"\n pyproject_data = toml.load(pyproject_path)\n return pyproject_data[\"tool\"][\"poetry\"][\"version\"]\n\n\ndef run_pytest():\n print(\"Running pytest\")\n try:\n run_command(\"pytest\", cwd=\"tests\")\n except subprocess.CalledProcessError:\n print(\"Pytest failed. Aborting build.\")\n sys.exit(1)\n\n\ndef run_all_examples():\n print(\"Running all examples\")\n try:\n run_command(\"python run_all_examples.py -w 16\", cwd=\"tests\")\n except subprocess.CalledProcessError:\n print(\"Some examples failed. Please review the output above.\")\n user_input = input(\"Do you want to continue with the build? (y/n): \").lower()\n if user_input != 'y':\n print(\"Aborting build.\")\n sys.exit(1)\n\n\ndef main():\n ell_version = get_ell_version()\n os.environ['REACT_APP_ELL_VERSION'] = ell_version\n npm_install()\n npm_build()\n run_pytest()\n run_all_examples()\n print(\"Build completed successfully.\")\n\n\nif __name__ == \"__main__\":\n main()",
"filepath": "build.py",
"metadata": {
"file_path": "build.py",
"file_name": "build.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 427,
"span_ids": [
"impl",
"npm_install",
"run_command",
"get_ell_version",
"npm_build",
"imports",
"run_pytest",
"run_all_examples",
"main"
],
"start_line": 1,
"end_line": 66,
"community": null
},
"node_id": "build.py::2"
},
{
"id": "0.1.0\\autostreamprevention.py::1",
"input_type": "file",
"content": "import openai\nimport os\n\n# Define the function to stream the response\ndef stream_openai_response(prompt):\n try:\n # Make the API call\n response = openai.chat.completions.create(\n model=\"o1-mini\", # Specify the model\n messages=[{\"role\": \"user\", \"content\": prompt}],\n stream=True # Enable streaming\n )\n\n # Stream the response\n for chunk in response:\n if chunk.choices[0].delta.get(\"content\"):\n print(chunk.choices[0].delta.content, end=\"\", flush=True)\n\n print() # Print a newline at the end\n\n except Exception as e:\n print(f\"An error occurred: {e}\")\n\n# Example usage\nprompt = \"Tell me a short joke.\"\nstream_openai_response(prompt)",
"filepath": "docs\\ramblings\\0.1.0\\autostreamprevention.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\autostreamprevention.py",
"file_name": "autostreamprevention.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 169,
"span_ids": [
"stream_openai_response",
"impl",
"imports"
],
"start_line": 1,
"end_line": 26,
"community": null
},
"node_id": "0.1.0\\autostreamprevention.py::1"
},
{
"id": "0.1.0\\cem.py::1",
"input_type": "file",
"content": "import gym\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport numpy as np\nfrom gym.vector import AsyncVectorEnv\nimport random\n\n# Set random seeds for reproducibility\nSEED = 42\nrandom.seed(SEED)\nnp.random.seed(SEED)\ntorch.manual_seed(SEED)\n\n# Hyperparameters\nNUM_ENVIRONMENTS = 4 # Reduced for simplicity\nNUM_ITERATIONS = 50 # Number of training iterations\nTRAJECTORIES_PER_ITER = 100 # Total number of trajectories per iteration\nELITE_PERCENT = 10 # Top k% trajectories to select\nLEARNING_RATE = 1e-3\nBATCH_SIZE = 64\nMAX_STEPS = 500 # Max steps per trajectory\nENV_NAME = 'CartPole-v1'",
"filepath": "docs\\ramblings\\0.1.0\\cem.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cem.py",
"file_name": "cem.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 168,
"span_ids": [
"imports"
],
"start_line": 1,
"end_line": 23,
"community": null
},
"node_id": "0.1.0\\cem.py::1"
},
{
"id": "0.1.0\\cem.py::2",
"input_type": "file",
"content": " # Gym environment\n\n# Define the Policy Network\nclass PolicyNetwork(nn.Module):\n def __init__(self, state_dim, action_dim, hidden_dim=128):\n super(PolicyNetwork, self).__init__()\n self.fc = nn.Sequential(\n nn.Linear(state_dim, hidden_dim),\n nn.ReLU(),\n nn.Linear(hidden_dim, hidden_dim),\n nn.ReLU(),\n nn.Linear(hidden_dim, action_dim)\n )\n\n def forward(self, state):\n logits = self.fc(state)\n return logits\n\n def get_action(self, state):\n logits = self.forward(state)\n action_probs = torch.softmax(logits, dim=-1)\n action = torch.multinomial(action_probs, num_samples=1)\n return action.squeeze(-1)\n\n# Function to create multiple environments\ndef make_env(env_name, seed):\n def _init():\n env = gym.make(env_name)\n return env\n return _init",
"filepath": "docs\\ramblings\\0.1.0\\cem.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cem.py",
"file_name": "cem.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 192,
"span_ids": [
"PolicyNetwork.get_action",
"PolicyNetwork.forward",
"imports",
"PolicyNetwork",
"make_env",
"PolicyNetwork.__init__"
],
"start_line": 23,
"end_line": 52,
"community": null
},
"node_id": "0.1.0\\cem.py::2"
},
{
"id": "0.1.0\\cem.py::3",
"input_type": "file",
"content": "def collect_trajectories(envs, policy, num_trajectories, max_steps):\n trajectories = []\n num_envs = envs.num_envs\n\n # Handle the return type of reset()\n reset_output = envs.reset()\n if isinstance(reset_output, tuple) or isinstance(reset_output, list):\n obs = reset_output[0] # Extract observations\n else:\n obs = reset_output\n\n done_envs = [False] * num_envs\n steps = 0\n\n # Initialize storage for states, actions, and rewards per environment\n env_states = [[] for _ in range(num_envs)]\n env_actions = [[] for _ in range(num_envs)]\n env_rewards = [0.0 for _ in range(num_envs)]\n total_collected = 0\n # ... other code",
"filepath": "docs\\ramblings\\0.1.0\\cem.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cem.py",
"file_name": "cem.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 175,
"span_ids": [
"collect_trajectories"
],
"start_line": 54,
"end_line": 72,
"community": null
},
"node_id": "0.1.0\\cem.py::3"
},
{
"id": "0.1.0\\cem.py::4",
"input_type": "file",
"content": "def collect_trajectories(envs, policy, num_trajectories, max_steps):\n # ... other code\n\n while total_collected < num_trajectories and steps < max_steps:\n # Convert observations to tensor efficiently\n try:\n # Ensure 'obs' is a NumPy array\n if not isinstance(obs, np.ndarray):\n print(f\"Unexpected type for observations: {type(obs)}\")\n raise ValueError(\"Observations are not a NumPy array.\")\n\n # Convert observations to tensor using from_numpy for efficiency\n obs_tensor = torch.from_numpy(obs).float()\n # Ensure the observation dimension matches expected\n assert obs_tensor.shape[1] == 4, f\"Expected observation dimension 4, got {obs_tensor.shape[1]}\"\n except Exception as e:\n print(f\"Error converting observations to tensor at step {steps}: {e}\")\n print(f\"Observations: {obs}\")\n raise e\n\n with torch.no_grad():\n actions = policy.get_action(obs_tensor).cpu().numpy()\n\n # Unpack step based on Gym version\n try:\n # For Gym versions >=0.26, step returns five values\n next_obs, rewards, dones, truncs, infos = envs.step(actions)\n except ValueError:\n # For older Gym versions, step returns four values\n next_obs, rewards, dones, infos = envs.step(actions)\n truncs = [False] * len(dones) # Assume no truncations if not provided\n\n # Handle the reset output of step()\n if isinstance(next_obs, tuple) or isinstance(next_obs, list):\n next_obs = next_obs[0] # Extract observations\n\n # Ensure infos is a list\n if not isinstance(infos, list):\n infos = [{} for _ in range(num_envs)] # Default to empty dicts\n\n for i in range(num_envs):\n if not done_envs[i]:\n # Check if obs[i] has the correct shape\n if len(obs[i]) != 4:\n print(f\"Unexpected observation shape for env {i}: {obs[i]}\")\n continue # Skip this step for the problematic environment\n\n env_states[i].append(obs[i])\n env_actions[i].append(actions[i])\n env_rewards[i] += rewards[i]\n if dones[i] or truncs[i]:\n # Extract reward from infos\n if isinstance(infos[i], dict):\n episode_info = infos[i].get('episode', {})\n traj_reward = episode_info.get('r') if 'r' in episode_info else env_rewards[i]\n else:\n # Handle cases where infos[i] is not a dict\n traj_reward = env_rewards[i]\n print(f\"Warning: infos[{i}] is not a dict. Received type: {type(infos[i])}\")\n\n trajectories.append({\n 'states': env_states[i],\n 'actions': env_actions[i],\n 'reward': traj_reward\n })\n total_collected += 1\n env_states[i] = []\n env_actions[i] = []\n env_rewards[i] = 0.0\n done_envs[i] = True\n\n obs = next_obs\n steps += 1\n\n # Reset environments that are done\n if any(done_envs):\n indices = [i for i, done in enumerate(done_envs) if done]\n if total_collected < num_trajectories:\n for i in indices:\n try:\n # Directly reset the environment\n reset_output = envs.envs[i].reset()\n if isinstance(reset_output, tuple) or isinstance(reset_output, list):\n # For Gym versions where reset returns (obs, info)\n obs[i] = reset_output[0]\n else:\n # For Gym versions where reset returns only obs\n obs[i] = reset_output\n done_envs[i] = False\n except Exception as e:\n print(f\"Error resetting environment {i}: {e}\")\n # Optionally, handle the failure (e.g., retry, terminate the environment)\n done_envs[i] = False # Prevent infinite loop\n\n return trajectories",
"filepath": "docs\\ramblings\\0.1.0\\cem.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cem.py",
"file_name": "cem.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 872,
"span_ids": [
"collect_trajectories"
],
"start_line": 74,
"end_line": 165,
"community": null
},
"node_id": "0.1.0\\cem.py::4"
},
{
"id": "0.1.0\\cem.py::5",
"input_type": "file",
"content": "def select_elite(trajectories, percentile=ELITE_PERCENT):\n rewards = [traj['reward'] for traj in trajectories]\n if not rewards:\n return []\n reward_threshold = np.percentile(rewards, 100 - percentile)\n elite_trajectories = [traj for traj in trajectories if traj['reward'] >= reward_threshold]\n return elite_trajectories\n\n# Function to create training dataset from elite trajectories\ndef create_training_data(elite_trajectories):\n states = []\n actions = []\n for traj in elite_trajectories:\n states.extend(traj['states'])\n actions.extend(traj['actions'])\n if not states or not actions:\n return None, None\n # Convert lists to NumPy arrays first for efficiency\n states = np.array(states, dtype=np.float32)\n actions = np.array(actions, dtype=np.int64)\n # Convert to PyTorch tensors\n states = torch.from_numpy(states)\n actions = torch.from_numpy(actions)\n return states, actions",
"filepath": "docs\\ramblings\\0.1.0\\cem.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cem.py",
"file_name": "cem.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 209,
"span_ids": [
"create_training_data",
"select_elite"
],
"start_line": 168,
"end_line": 191,
"community": null
},
"node_id": "0.1.0\\cem.py::5"
},
{
"id": "0.1.0\\cem.py::6",
"input_type": "file",
"content": "# Main execution code\nif __name__ == '__main__':\n # Initialize environments\n env_fns = [make_env(ENV_NAME, SEED + i) for i in range(NUM_ENVIRONMENTS)]\n envs = AsyncVectorEnv(env_fns)\n\n # Get environment details\n dummy_env = gym.make(ENV_NAME)\n state_dim = dummy_env.observation_space.shape[0]\n action_dim = dummy_env.action_space.n\n dummy_env.close()\n\n # Initialize policy network and optimizer\n policy = PolicyNetwork(state_dim, action_dim)\n optimizer = optim.Adam(policy.parameters(), lr=LEARNING_RATE)\n criterion = nn.CrossEntropyLoss()\n\n # Training Loop\n for iteration in range(1, NUM_ITERATIONS + 1):\n try:\n # Step 1: Collect Trajectories\n trajectories = collect_trajectories(envs, policy, TRAJECTORIES_PER_ITER, MAX_STEPS)\n except Exception as e:\n print(f\"Error during trajectory collection at iteration {iteration}: {e}\")\n break\n\n # Step 2: Select Elite Trajectories\n elite_trajectories = select_elite(trajectories, ELITE_PERCENT)\n\n if len(elite_trajectories) == 0:\n print(f\"Iteration {iteration}: No elite trajectories found. Skipping update.\")\n continue\n\n # Step 3: Create Training Data\n states, actions = create_training_data(elite_trajectories)\n\n if states is None or actions is None:\n print(f\"Iteration {iteration}: No training data available. Skipping update.\")\n continue\n\n # Step 4: Behavioral Cloning (Policy Update)\n dataset_size = states.size(0)\n indices = np.arange(dataset_size)\n np.random.shuffle(indices)\n\n for start in range(0, dataset_size, BATCH_SIZE):\n end = start + BATCH_SIZE\n batch_indices = indices[start:end]\n batch_states = states[batch_indices]\n batch_actions = actions[batch_indices]\n\n optimizer.zero_grad()\n logits = policy(batch_states)\n loss = criterion(logits, batch_actions)\n loss.backward()\n optimizer.step()\n\n # Step 5: Evaluate Current Policy\n avg_reward = np.mean([traj['reward'] for traj in elite_trajectories])\n print(f\"Iteration {iteration}: Elite Trajectories: {len(elite_trajectories)}, Average Reward: {avg_reward:.2f}\")\n\n # Close environments\n envs.close()\n\n # Testing the Trained Policy\n def test_policy(policy, env_name=ENV_NAME, episodes=5, max_steps=500):\n env = gym.make(env_name)\n total_rewards = []\n for episode in range(episodes):\n obs, _ = env.reset()\n done = False\n episode_reward = 0\n for _ in range(max_steps):\n obs_tensor = torch.from_numpy(obs).float().unsqueeze(0)\n with torch.no_grad():\n action = policy.get_action(obs_tensor).item()\n obs, reward, done, info, _ = env.step(action)\n episode_reward += reward\n if done:\n break\n total_rewards.append(episode_reward)\n print(f\"Test Episode {episode + 1}: Reward: {episode_reward}\")\n env.close()\n print(f\"Average Test Reward over {episodes} episodes: {np.mean(total_rewards):.2f}\")\n\n # Run the test\n test_policy(policy)",
"filepath": "docs\\ramblings\\0.1.0\\cem.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cem.py",
"file_name": "cem.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 717,
"span_ids": [
"create_training_data",
"impl:22"
],
"start_line": 193,
"end_line": 279,
"community": null
},
"node_id": "0.1.0\\cem.py::6"
},
{
"id": "0.1.0\\context_versioning.py::1",
"input_type": "file",
"content": "import inspect\nimport ast\nfrom contextlib import contextmanager\n\n@contextmanager\ndef context():\n # Get the current frame\n frame = inspect.currentframe()\n try:\n # Get the caller's frame\n caller_frame = frame.f_back.f_back\n # Get the filename and line number where the context manager is called\n filename = caller_frame.f_code.co_filename\n lineno = caller_frame.f_lineno\n\n # Read the source code from the file\n with open(filename, 'r') as f:\n source = f.read()\n\n # Parse the source code into an AST\n parsed = ast.parse(source, filename)\n # print(source)\n # Find the 'with' statement at the given line number\n class WithVisitor(ast.NodeVisitor):\n def __init__(self, target_lineno):\n self.target_lineno = target_lineno\n self.with_node = None\n\n def visit_With(self, node):\n if node.lineno <= self.target_lineno <= node.end_lineno:\n self.with_node = node\n self.generic_visit(node)\n\n visitor = WithVisitor(lineno)\n visitor.visit(parsed)\n\n # print(parsed, source)\n if visitor.with_node:\n # Extract the source code of the block inside 'with'\n start = visitor.with_node.body[0].lineno\n end = visitor.with_node.body[-1].end_lineno\n block_source = '\\n'.join(source.splitlines()[start-1:end])\n print(\"Source code inside 'with' block:\")\n print(block_source)\n else:\n print(\"Could not find the 'with' block.\")\n\n # Yield control to the block inside 'with'\n yield\n finally:\n # Any cleanup can be done here\n pass\n\nfrom context_versioning import context\n# Example usage\nif __name__ == \"__main__\":\n with context():\n x = 10\n y = x * 2\n print(y)",
"filepath": "docs\\ramblings\\0.1.0\\context_versioning.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\context_versioning.py",
"file_name": "context_versioning.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 405,
"span_ids": [
"impl",
"imports",
"context"
],
"start_line": 2,
"end_line": 63,
"community": null
},
"node_id": "0.1.0\\context_versioning.py::1"
},
{
"id": "0.1.0\\cpbo.py::1",
"input_type": "file",
"content": "import gym\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport numpy as np\nfrom collections import namedtuple\nfrom torch.utils.data import DataLoader, TensorDataset\n\n# Define a simple policy network\nclass PolicyNetwork(nn.Module):\n def __init__(self, state_dim, action_dim, hidden_dim=128):\n super(PolicyNetwork, self).__init__()\n self.network = nn.Sequential(\n nn.Linear(state_dim, hidden_dim),\n nn.ReLU(),\n nn.Linear(hidden_dim, hidden_dim),\n nn.ReLU(),\n nn.Linear(hidden_dim, action_dim),\n nn.Softmax(dim=-1) # Output action probabilities\n )\n\n def forward(self, x):\n return self.network(x)",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 151,
"span_ids": [
"PolicyNetwork.forward",
"PolicyNetwork.__init__",
"imports",
"PolicyNetwork"
],
"start_line": 1,
"end_line": 23,
"community": null
},
"node_id": "0.1.0\\cpbo.py::1"
},
{
"id": "0.1.0\\cpbo.py::2",
"input_type": "file",
"content": "# Function to collect trajectories\ndef collect_trajectories(env, policy, num_episodes, device):\n trajectories = []\n Episode = namedtuple('Episode', ['states', 'actions', 'rewards'])\n\n for episode_num in range(num_episodes):\n states = []\n actions = []\n rewards = []\n # Handle Gym's updated reset() API\n state, info = env.reset(seed=42 + episode_num) # Optional: set seed for reproducibility\n done = False\n\n while not done:\n state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n with torch.no_grad():\n action_probs = policy(state_tensor)\n action_dist = torch.distributions.Categorical(action_probs)\n action = action_dist.sample().item()\n\n # Handle Gym's updated step() API\n next_state, reward, terminated, truncated, info = env.step(action)\n done = terminated or truncated\n\n states.append(state)\n actions.append(action)\n rewards.append(reward)\n\n state = next_state\n\n trajectories.append(Episode(states, actions, rewards))\n\n return trajectories\n\n# Function to compute returns\ndef compute_returns(trajectories, gamma=0.99):\n all_returns = []\n for episode in trajectories:\n returns = []\n G = 0\n for reward in reversed(episode.rewards):\n G = reward + gamma * G\n returns.insert(0, G)\n all_returns.extend(returns)\n return all_returns",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 305,
"span_ids": [
"PolicyNetwork.forward",
"collect_trajectories",
"compute_returns"
],
"start_line": 25,
"end_line": 69,
"community": null
},
"node_id": "0.1.0\\cpbo.py::2"
},
{
"id": "0.1.0\\cpbo.py::3",
"input_type": "file",
"content": "# Function to create labeled dataset\ndef create_labeled_dataset(trajectories, gamma=0.99, device='cpu'):\n states = []\n actions = []\n labels = []\n\n all_returns = compute_returns(trajectories, gamma)\n all_returns = np.array(all_returns)\n median_return = np.median(all_returns)\n\n for episode in trajectories:\n for t in range(len(episode.rewards)):\n # Compute return from timestep t\n G = sum([gamma**k * episode.rewards[t + k] for k in range(len(episode.rewards) - t)])\n label = 1 if G >= median_return else 0\n states.append(episode.states[t])\n actions.append(episode.actions[t])\n labels.append(label)\n\n # Convert lists to NumPy arrays first for efficiency\n states = np.array(states)\n actions = np.array(actions)\n labels = np.array(labels)\n\n # Convert to PyTorch tensors\n states = torch.FloatTensor(states).to(device)\n actions = torch.LongTensor(actions).to(device)\n labels = torch.FloatTensor(labels).to(device)\n\n return states, actions, labels",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 237,
"span_ids": [
"create_labeled_dataset",
"compute_returns"
],
"start_line": 71,
"end_line": 100,
"community": null
},
"node_id": "0.1.0\\cpbo.py::3"
},
{
"id": "0.1.0\\cpbo.py::4",
"input_type": "file",
"content": "# Function to perform behavioral cloning update\ndef behavioral_cloning_update(policy, optimizer, dataloader, device):\n criterion = nn.BCELoss()\n policy.train()\n\n for states, actions, labels in dataloader:\n optimizer.zero_grad()\n action_probs = policy(states)\n # Gather the probability of the taken action\n selected_probs = action_probs.gather(1, actions.unsqueeze(1)).squeeze(1)\n # Labels are 1 for good actions, 0 for bad actions\n loss = criterion(selected_probs, labels)\n loss.backward()\n optimizer.step()",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 122,
"span_ids": [
"create_labeled_dataset",
"behavioral_cloning_update"
],
"start_line": 102,
"end_line": 115,
"community": null
},
"node_id": "0.1.0\\cpbo.py::4"
},
{
"id": "0.1.0\\cpbo.py::5",
"input_type": "file",
"content": "# Evaluation function\ndef evaluate_policy(env, policy, device, episodes=5):\n policy.eval()\n total_rewards = []\n for _ in range(episodes):\n state, info = env.reset()\n done = False\n ep_reward = 0\n while not done:\n state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n with torch.no_grad():\n action_probs = policy(state_tensor)\n action = torch.argmax(action_probs, dim=1).item()\n # Handle Gym's updated step() API\n next_state, reward, terminated, truncated, info = env.step(action)\n done = terminated or truncated\n ep_reward += reward\n state = next_state\n total_rewards.append(ep_reward)\n average_reward = np.mean(total_rewards)\n return average_reward",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 167,
"span_ids": [
"behavioral_cloning_update",
"evaluate_policy"
],
"start_line": 117,
"end_line": 137,
"community": null
},
"node_id": "0.1.0\\cpbo.py::5"
},
{
"id": "0.1.0\\cpbo.py::6",
"input_type": "file",
"content": "# Main CBPO algorithm\ndef CBPO(env_name='CartPole-v1', num_epochs=10, num_episodes_per_epoch=100, gamma=0.99, \n batch_size=64, learning_rate=1e-3, device='cpu'):\n\n env = gym.make(env_name)\n state_dim = env.observation_space.shape[0]\n action_dim = env.action_space.n\n\n policy = PolicyNetwork(state_dim, action_dim).to(device)\n optimizer = optim.Adam(policy.parameters(), lr=learning_rate)\n\n for epoch in range(num_epochs):\n print(f\"Epoch {epoch+1}/{num_epochs}\")\n\n # 1. Collect trajectories\n trajectories = collect_trajectories(env, policy, num_episodes_per_epoch, device)\n\n # 2. Create labeled dataset\n states, actions, labels = create_labeled_dataset(trajectories, gamma, device)\n\n # 3. Create DataLoader\n dataset = TensorDataset(states, actions, labels)\n dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n\n # 4. Behavioral Cloning Update\n behavioral_cloning_update(policy, optimizer, dataloader, device)\n\n # 5. Evaluate current policy\n avg_reward = evaluate_policy(env, policy, device)\n print(f\"Average Reward: {avg_reward}\")\n\n # Early stopping if solved\n if avg_reward >= env.spec.reward_threshold:\n print(f\"Environment solved in {epoch+1} epochs!\")\n break\n\n env.close()\n return policy",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 319,
"span_ids": [
"CBPO",
"evaluate_policy"
],
"start_line": 139,
"end_line": 176,
"community": null
},
"node_id": "0.1.0\\cpbo.py::6"
},
{
"id": "0.1.0\\cpbo.py::7",
"input_type": "file",
"content": "if __name__ == \"__main__\":\n # Check if GPU is available\n device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n print(f\"Using device: {device}\")\n\n # Run CBPO\n trained_policy = CBPO(\n env_name='CartPole-v1',\n num_epochs=50,\n num_episodes_per_epoch=500,\n gamma=0.99,\n batch_size=64,\n learning_rate=1e-3,\n device=device\n )\n\n # Final Evaluation\n env = gym.make('CartPole-v1')\n final_avg_reward = evaluate_policy(env, trained_policy, device, episodes=20)\n print(f\"Final Average Reward over 20 episodes: {final_avg_reward}\")\n env.close()\n\n # Save the trained policy\n torch.save(trained_policy.state_dict(), \"trained_cartpole_policy.pth\")\n print(\"Trained policy saved to trained_cartpole_policy.pth\")\n\n # Demo the trained policy with rendering\n env = gym.make('CartPole-v1', render_mode='human')\n state, _ = env.reset()\n done = False\n total_reward = 0\n\n while not done:\n state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)\n action = trained_policy(state_tensor).argmax().item()\n state, reward, terminated, truncated, _ = env.step(action)\n total_reward += reward\n done = terminated or truncated\n env.render()\n\n print(f\"Demo episode finished with total reward: {total_reward}\")\n env.close()",
"filepath": "docs\\ramblings\\0.1.0\\cpbo.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\cpbo.py",
"file_name": "cpbo.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 333,
"span_ids": [
"impl"
],
"start_line": 178,
"end_line": 219,
"community": null
},
"node_id": "0.1.0\\cpbo.py::7"
},
{
"id": "0.1.0\\metapromptingtorch.py::1",
"input_type": "file",
"content": "import torch as th\n\n\nweights = th.nn.Parameter(th.randn(10))\n\n\ndef forward(x):\n return x * weights\n\n\nx = th.randn(10)\n\nprint(forward(x))\nprint(weights)\n\n# OOOH WAHT IF WE DID MANY TYPES OF LEARNABLES in",
"filepath": "docs\\ramblings\\0.1.0\\metapromptingtorch.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\metapromptingtorch.py",
"file_name": "metapromptingtorch.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 57,
"span_ids": [
"forward",
"impl:3",
"imports"
],
"start_line": 3,
"end_line": 18,
"community": null
},
"node_id": "0.1.0\\metapromptingtorch.py::1"
},
{
"id": "0.1.0\\mypytest.py::1",
"input_type": "file",
"content": "from typing import TypedDict\n\n\nclass Test(TypedDict):\n name: str\n age: int\n\n\ndef test(**t: Test):\n print(t)\n\n# no type hinting like ts thats unfortunate.\ntest( )",
"filepath": "docs\\ramblings\\0.1.0\\mypytest.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\mypytest.py",
"file_name": "mypytest.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 46,
"span_ids": [
"Test",
"test",
"impl",
"imports"
],
"start_line": 1,
"end_line": 14,
"community": null
},
"node_id": "0.1.0\\mypytest.py::1"
},
{
"id": "0.1.0\\test.py::1",
"input_type": "file",
"content": "from typing import Callable\n\n# The follwoing works...\n\n\n\ndef decorator(fn : Callable):\n def wrapper(*args, **kwargs):\n print(\"before\")\n result = fn(*args, **kwargs)\n print(\"after\")\n return result\n return wrapper\n\n\nclass TestCallable:\n def __init__(self, fn : Callable):\n self.fn = fn\n\n def __call__(self, *args, **kwargs):\n return self.fn(*args, **kwargs)\n\ndef convert_to_test_callable(fn : Callable):\n return TestCallable(fn)\n\nx = TestCallable(lambda : 1)\n\n@decorator\n@convert_to_test_callable\ndef test():\n print(\"test\")\n\n@decorator\nclass MyCallable:\n def __init__(self, fn : Callable):\n self.fn = fn\n\n def __call__(self, *args, **kwargs):\n return self.fn(*args, **kwargs)\n\n# Oh so now ell.simples can actually be used as decorators on classes\r",
"filepath": "docs\\ramblings\\0.1.0\\test.py",
"metadata": {
"file_path": "docs\\ramblings\\0.1.0\\test.py",
"file_name": "test.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 208,
"span_ids": [
"impl",
"test",
"MyCallable",
"TestCallable.__init__",
"MyCallable.__call__",
"decorator",
"MyCallable.__init__",
"TestCallable",
"imports",
"TestCallable.__call__",
"convert_to_test_callable"
],
"start_line": 2,
"end_line": 43,
"community": null
},
"node_id": "0.1.0\\test.py::1"
},
{
"id": "src\\conf.py::1",
"input_type": "file",
"content": "# Configuration file for the Sphinx documentation builder.\n#\n# For the full list of built-in configuration values, see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Project information -----------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information\n\nproject = 'ell'\ncopyright = '2024, William Guss'\nauthor = 'William Guss'\n\n# -- General configuration ---------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration\nextensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinxawesome_theme', 'sphinxcontrib.autodoc_pydantic']\n\ntemplates_path = ['_templates']\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\nhtml_theme = \"sphinxawesome_theme\"\n\n\n# Favicon configuration\nhtml_favicon = '_static/favicon.ico'\n\n# Configure syntax highlighting for Awesome Sphinx Theme\npygments_style = \"default\"\npygments_style_dark = \"dracula\"\n\n# Additional theme configuration\r",
"filepath": "docs\\src\\conf.py",
"metadata": {
"file_path": "docs\\src\\conf.py",
"file_name": "conf.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 223,
"span_ids": [
"docstring"
],
"start_line": 1,
"end_line": 30,
"community": null
},
"node_id": "src\\conf.py::1"
},
{
"id": "src\\conf.py::2",
"input_type": "file",
"content": "html_theme_options = {\n \"show_prev_next\": True,\n \"show_scrolltop\": True,\n \"main_nav_links\": {\n \"Docs\": \"index\",\n \"API Reference\": \"reference/index\",\n \"AI Jobs Board\": \"https://jobs.ell.so\",\n },\n \"extra_header_link_icons\": {\n \"Discord\": {\n \"link\": \"https://discord.gg/vWntgU52Xb\",\n \"icon\": \"\"\"<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 640 512\" height=\"18\" fill=\"currentColor\"><!--!Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free Copyright 2024 Fonticons, Inc.--><path d=\"M524.5 69.8a1.5 1.5 0 0 0 -.8-.7A485.1 485.1 0 0 0 404.1 32a1.8 1.8 0 0 0 -1.9 .9 337.5 337.5 0 0 0 -14.9 30.6 447.8 447.8 0 0 0 -134.4 0 309.5 309.5 0 0 0 -15.1-30.6 1.9 1.9 0 0 0 -1.9-.9A483.7 483.7 0 0 0 116.1 69.1a1.7 1.7 0 0 0 -.8 .7C39.1 183.7 18.2 294.7 28.4 404.4a2 2 0 0 0 .8 1.4A487.7 487.7 0 0 0 176 479.9a1.9 1.9 0 0 0 2.1-.7A348.2 348.2 0 0 0 208.1 430.4a1.9 1.9 0 0 0 -1-2.6 321.2 321.2 0 0 1 -45.9-21.9 1.9 1.9 0 0 1 -.2-3.1c3.1-2.3 6.2-4.7 9.1-7.1a1.8 1.8 0 0 1 1.9-.3c96.2 43.9 200.4 43.9 295.5 0a1.8 1.8 0 0 1 1.9 .2c2.9 2.4 6 4.9 9.1 7.2a1.9 1.9 0 0 1 -.2 3.1 301.4 301.4 0 0 1 -45.9 21.8 1.9 1.9 0 0 0 -1 2.6 391.1 391.1 0 0 0 30 48.8 1.9 1.9 0 0 0 2.1 .7A486 486 0 0 0 610.7 405.7a1.9 1.9 0 0 0 .8-1.4C623.7 277.6 590.9 167.5 524.5 69.8zM222.5 337.6c-29 0-52.8-26.6-52.8-59.2S193.1 219.1 222.5 219.1c29.7 0 53.3 26.8 52.8 59.2C275.3 311 251.9 337.6 222.5 337.6zm195.4 0c-29 0-52.8-26.6-52.8-59.2S388.4 219.1 417.9 219.1c29.7 0 53.3 26.8 52.8 59.2C470.7 311 447.5 337.6 417.9 337.6z\"/></svg>\"\"\",\n \"type\": \"font-awesome\",\n \"name\": \"Discord\",\n },\n },\n\n \"logo_light\": \"_static/ell-wide-light.png\",\n \"logo_dark\": \"_static/ell-wide-dark.png\",\n \n}\n\nhtml_static_path = ['_static']\n\n\n\ntemplates_path = ['_templates']",
"filepath": "docs\\src\\conf.py",
"metadata": {
"file_path": "docs\\src\\conf.py",
"file_name": "conf.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 1044,
"span_ids": [
"impl:23",
"docstring"
],
"start_line": 31,
"end_line": 58,
"community": null
},
"node_id": "src\\conf.py::2"
},
{
"id": "ell\\__init__.py::1",
"input_type": "file",
"content": "\"\"\"\nell is a Python library for language model programming (LMP). It provides a simple\nand intuitive interface for working with large language models.\n\"\"\"\n\n\nfrom ell.lmp.simple import simple\nfrom ell.lmp.tool import tool\nfrom ell.lmp.complex import complex\nfrom ell.types.message import system, user, assistant, Message, ContentBlock\nfrom ell.__version__ import __version__\n\n# Import all models\nimport ell.providers\nimport ell.models\n\n\n# Import everything from configurator\nfrom ell.configurator import *",
"filepath": "src\\ell\\__init__.py",
"metadata": {
"file_path": "src\\ell\\__init__.py",
"file_name": "__init__.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 106,
"span_ids": [
"docstring"
],
"start_line": 1,
"end_line": 20,
"community": null
},
"node_id": "ell\\__init__.py::1"
},
{
"id": "ell\\__version__.py::1",
"input_type": "file",
"content": "try:\n from importlib.metadata import version\nexcept ImportError:\n from importlib_metadata import version\n\n__version__ = version(\"ell-ai\")",
"filepath": "src\\ell\\__version__.py",
"metadata": {
"file_path": "src\\ell\\__version__.py",
"file_name": "__version__.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 31,
"span_ids": [
"impl"
],
"start_line": 1,
"end_line": 7,
"community": null
},
"node_id": "ell\\__version__.py::1"
},
{
"id": "ell\\configurator.py::1",
"input_type": "file",
"content": "from functools import lru_cache, wraps\nfrom typing import Dict, Any, Optional, Tuple, Union, Type\nimport openai\nimport logging\nfrom contextlib import contextmanager\nimport threading\nfrom pydantic import BaseModel, ConfigDict, Field\nfrom ell.store import Store\nfrom ell.provider import Provider\nfrom dataclasses import dataclass, field\n\n_config_logger = logging.getLogger(__name__)\n\n@dataclass(frozen=True)\nclass _Model:\n name: str\n default_client: Optional[Union[openai.Client, Any]] = None\n #XXX: Deprecation in 0.1.0\n #XXX: We will depreciate this when streaming is implemented. \n # Currently we stream by default for the verbose renderer,\n # but in the future we will not support streaming by default \n # and stream=True must be passed which will then make API providers the\n # single source of truth for whether or not a model supports an api parameter.\n # This makes our implementation extremely light, only requiring us to provide\n # a list of model names in registration.\n supports_streaming : Optional[bool] = field(default=None)",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 243,
"span_ids": [
"imports",
"_Model"
],
"start_line": 1,
"end_line": 26,
"community": null
},
"node_id": "ell\\configurator.py::1"
},
{
"id": "ell\\configurator.py::2",
"input_type": "file",
"content": "class Config(BaseModel):\n model_config = ConfigDict(arbitrary_types_allowed=True)\n registry: Dict[str, _Model] = Field(default_factory=dict, description=\"A dictionary mapping model names to their configurations.\")\n verbose: bool = Field(default=False, description=\"If True, enables verbose logging.\")\n wrapped_logging: bool = Field(default=True, description=\"If True, enables wrapped logging for better readability.\")\n override_wrapped_logging_width: Optional[int] = Field(default=None, description=\"If set, overrides the default width for wrapped logging.\")\n store: Optional[Store] = Field(default=None, description=\"An optional Store instance for persistence.\")\n autocommit: bool = Field(default=False, description=\"If True, enables automatic committing of changes to the store.\")\n lazy_versioning: bool = Field(default=True, description=\"If True, enables lazy versioning for improved performance.\")\n default_api_params: Dict[str, Any] = Field(default_factory=dict, description=\"Default parameters for language models.\")\n default_client: Optional[openai.Client] = Field(default=None, description=\"The default OpenAI client used when a specific model client is not found.\")\n autocommit_model: str = Field(default=\"gpt-4o-mini\", description=\"When set, changes the default autocommit model from GPT 4o mini.\")\n providers: Dict[Type, Provider] = Field(default_factory=dict, description=\"A dictionary mapping client types to provider classes.\")\n def __init__(self, **data):\n super().__init__(**data)\n self._lock = threading.Lock()\n self._local = threading.local()",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 335,
"span_ids": [
"Config.__init__",
"Config"
],
"start_line": 30,
"end_line": 46,
"community": null
},
"node_id": "ell\\configurator.py::2"
},
{
"id": "ell\\configurator.py::3",
"input_type": "file",
"content": "class Config(BaseModel):\n\n\n def register_model(\n self, \n name: str,\n default_client: Optional[Union[openai.Client, Any]] = None,\n supports_streaming: Optional[bool] = None\n ) -> None:\n \"\"\"\n Register a model with its configuration.\n \"\"\"\n with self._lock:\n # XXX: Will be deprecated in 0.1.0\n self.registry[name] = _Model(\n name=name,\n default_client=default_client,\n supports_streaming=supports_streaming\n )",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 118,
"span_ids": [
"Config.register_model"
],
"start_line": 49,
"end_line": 64,
"community": null
},
"node_id": "ell\\configurator.py::3"
},
{
"id": "ell\\configurator.py::4",
"input_type": "file",
"content": "class Config(BaseModel):\n\n\n\n @contextmanager\n def model_registry_override(self, overrides: Dict[str, _Model]):\n \"\"\"\n Temporarily override the model registry with new model configurations.\n\n :param overrides: A dictionary of model names to ModelConfig instances to override.\n :type overrides: Dict[str, ModelConfig]\n \"\"\"\n if not hasattr(self._local, 'stack'):\n self._local.stack = []\n\n with self._lock:\n current_registry = self._local.stack[-1] if self._local.stack else self.registry\n new_registry = current_registry.copy()\n new_registry.update(overrides)\n\n self._local.stack.append(new_registry)\n try:\n yield\n finally:\n self._local.stack.pop()",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 154,
"span_ids": [
"Config.model_registry_override"
],
"start_line": 68,
"end_line": 88,
"community": null
},
"node_id": "ell\\configurator.py::4"
},
{
"id": "ell\\configurator.py::5",
"input_type": "file",
"content": "class Config(BaseModel):\n\n def get_client_for(self, model_name: str) -> Tuple[Optional[openai.Client], bool]:\n \"\"\"\n Get the OpenAI client for a specific model name.\n\n :param model_name: The name of the model to get the client for.\n :type model_name: str\n :return: The OpenAI client for the specified model, or None if not found, and a fallback flag.\n :rtype: Tuple[Optional[openai.Client], bool]\n \"\"\"\n current_registry = self._local.stack[-1] if hasattr(self._local, 'stack') and self._local.stack else self.registry\n model_config = current_registry.get(model_name)\n fallback = False\n if not model_config:\n warning_message = f\"Warning: A default provider for model '{model_name}' could not be found. Falling back to default OpenAI client from environment variables.\"\n if self.verbose:\n from colorama import Fore, Style\n _config_logger.warning(f\"{Fore.LIGHTYELLOW_EX}{warning_message}{Style.RESET_ALL}\")\n else:\n _config_logger.debug(warning_message)\n client = self.default_client\n fallback = True\n else:\n client = model_config.default_client\n return client, fallback",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 264,
"span_ids": [
"Config.get_client_for"
],
"start_line": 90,
"end_line": 113,
"community": null
},
"node_id": "ell\\configurator.py::5"
},
{
"id": "ell\\configurator.py::6",
"input_type": "file",
"content": "class Config(BaseModel):\n\n def register_provider(self, provider: Provider, client_type: Type[Any]) -> None:\n \"\"\"\n Register a provider class for a specific client type.\n\n :param provider_class: The provider class to register.\n :type provider_class: Type[Provider]\n \"\"\"\n assert isinstance(client_type, type), \"client_type must be a type (e.g. openai.Client), not an an instance (myclient := openai.Client()))\"\n with self._lock:\n self.providers[client_type] = provider",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 112,
"span_ids": [
"Config.register_provider"
],
"start_line": 115,
"end_line": 124,
"community": null
},
"node_id": "ell\\configurator.py::6"
},
{
"id": "ell\\configurator.py::7",
"input_type": "file",
"content": "class Config(BaseModel):\n\n def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]:\n \"\"\"\n Get the provider instance for a specific client instance.\n\n :param client: The client instance to get the provider for.\n :type client: Any\n :return: The provider instance for the specified client, or None if not found.\n :rtype: Optional[Provider]\n \"\"\"\n\n client_type = type(client) if not isinstance(client, type) else client\n for provider_type, provider in self.providers.items():\n if issubclass(client_type, provider_type) or client_type == provider_type:\n return provider\n return None\n\n# Single* instance\n# XXX: Make a singleton\nconfig = Config()",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 157,
"span_ids": [
"Config.get_provider_for",
"impl:3"
],
"start_line": 126,
"end_line": 144,
"community": null
},
"node_id": "ell\\configurator.py::7"
},
{
"id": "ell\\configurator.py::8",
"input_type": "file",
"content": "def init(\n store: Optional[Union[Store, str]] = None,\n verbose: bool = False,\n autocommit: bool = True,\n lazy_versioning: bool = True,\n default_api_params: Optional[Dict[str, Any]] = None,\n default_client: Optional[Any] = None,\n autocommit_model: str = \"gpt-4o-mini\"\n) -> None:\n \"\"\"\n Initialize the ELL configuration with various settings.\n\n :param verbose: Set verbosity of ELL operations.\n :type verbose: bool\n :param store: Set the store for ELL. Can be a Store instance or a string path for SQLiteStore.\n :type store: Union[Store, str], optional\n :param autocommit: Set autocommit for the store operations.\n :type autocommit: bool\n :param lazy_versioning: Enable or disable lazy versioning.\n :type lazy_versioning: bool\n :param default_api_params: Set default parameters for language models.\n :type default_api_params: Dict[str, Any], optional\n :param default_openai_client: Set the default OpenAI client.\n :type default_openai_client: openai.Client, optional\n :param autocommit_model: Set the model used for autocommitting.\n :type autocommit_model: str\n \"\"\"\n # XXX: prevent double init\n config.verbose = verbose\n config.lazy_versioning = lazy_versioning\n\n if isinstance(store, str):\n from ell.stores.sql import SQLiteStore\n config.store = SQLiteStore(store)\n else:\n config.store = store\n config.autocommit = autocommit or config.autocommit\n\n if default_api_params is not None:\n config.default_api_params.update(default_api_params)\n\n if default_client is not None:\n config.default_client = default_client\n\n if autocommit_model is not None:\n config.autocommit_model = autocommit_model",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 420,
"span_ids": [
"init"
],
"start_line": 146,
"end_line": 191,
"community": null
},
"node_id": "ell\\configurator.py::8"
},
{
"id": "ell\\configurator.py::9",
"input_type": "file",
"content": "# Existing helper functions\ndef get_store() -> Union[Store, None]:\n return config.store\n\n# Will be deprecated at 0.1.0 \n\n# You can add more helper functions here if needed\ndef register_provider(provider: Provider, client_type: Type[Any]) -> None:\n return config.register_provider(provider, client_type)\n\n# Deprecated now (remove at 0.1.0)\ndef set_store(*args, **kwargs) -> None:\n raise DeprecationWarning(\"The set_store function is deprecated and will be removed in a future version. Use ell.init(store=...) instead.\")",
"filepath": "src\\ell\\configurator.py",
"metadata": {
"file_path": "src\\ell\\configurator.py",
"file_name": "configurator.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 126,
"span_ids": [
"get_store",
"set_store",
"init",
"register_provider"
],
"start_line": 193,
"end_line": 205,
"community": null
},
"node_id": "ell\\configurator.py::9"
},
{
"id": "lmp\\__init__.py::1",
"input_type": "file",
"content": "from ell.lmp.simple import simple\nfrom ell.lmp.complex import complex",
"filepath": "src\\ell\\lmp\\__init__.py",
"metadata": {
"file_path": "src\\ell\\lmp\\__init__.py",
"file_name": "__init__.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 16,
"span_ids": [
"imports"
],
"start_line": 1,
"end_line": 2,
"community": null
},
"node_id": "lmp\\__init__.py::1"
},
{
"id": "lmp\\_track.py::1",
"input_type": "file",
"content": "import json\nimport logging\nimport threading\nfrom ell.types import SerializedLMP, Invocation, InvocationTrace, InvocationContents\nfrom ell.types.studio import LMPType, utc_now\nfrom ell.util._warnings import _autocommit_warning\nimport ell.util.closure\nfrom ell.configurator import config\nfrom ell.types._lstr import _lstr\n\nimport inspect\n\nimport secrets\nimport time\nfrom datetime import datetime\nfrom functools import wraps\nfrom typing import Any, Callable, Dict, Iterable, Optional, OrderedDict, Tuple\n\nfrom ell.util.serialization import get_immutable_vars\nfrom ell.util.serialization import compute_state_cache_key\nfrom ell.util.serialization import prepare_invocation_params\n\nlogger = logging.getLogger(__name__)\n\n# Thread-local storage for the invocation stack\n_invocation_stack = threading.local()\n\ndef get_current_invocation() -> Optional[str]:\n if not hasattr(_invocation_stack, 'stack'):\n _invocation_stack.stack = []\n return _invocation_stack.stack[-1] if _invocation_stack.stack else None\n\ndef push_invocation(invocation_id: str):\n if not hasattr(_invocation_stack, 'stack'):\n _invocation_stack.stack = []\n _invocation_stack.stack.append(invocation_id)\n\ndef pop_invocation():\n if hasattr(_invocation_stack, 'stack') and _invocation_stack.stack:\n _invocation_stack.stack.pop()",
"filepath": "src\\ell\\lmp\\_track.py",
"metadata": {
"file_path": "src\\ell\\lmp\\_track.py",
"file_name": "_track.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 284,
"span_ids": [
"push_invocation",
"get_current_invocation",
"imports",
"pop_invocation"
],
"start_line": 1,
"end_line": 40,
"community": null
},
"node_id": "lmp\\_track.py::1"
},
{
"id": "lmp\\_track.py::2",
"input_type": "file",
"content": "def _track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None) -> Callable:\n\n lmp_type = getattr(func_to_track, \"__ell_type__\", LMPType.OTHER)\n\n\n # see if it exists\n if not hasattr(func_to_track, \"_has_serialized_lmp\"):\n func_to_track._has_serialized_lmp = False\n\n if not hasattr(func_to_track, \"__ell_hash__\") and not config.lazy_versioning:\n ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)\n\n\n @wraps(func_to_track)\n def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str:\n # XXX: Cache keys and global variable binding is not thread safe.\n # Compute the invocation id and hash the inputs for serialization.\n invocation_id = \"invocation-\" + secrets.token_hex(16)\n\n state_cache_key : str = None\n if not config.store:\n return func_to_track(*fn_args, **fn_kwargs, _invocation_origin=invocation_id)[0]\n\n parent_invocation_id = get_current_invocation()\n try:\n push_invocation(invocation_id)\n\n # Convert all positional arguments to named keyword arguments\n sig = inspect.signature(func_to_track)\n # Filter out kwargs that are not in the function signature\n filtered_kwargs = {k: v for k, v in fn_kwargs.items() if k in sig.parameters}\n\n bound_args = sig.bind(*fn_args, **filtered_kwargs)\n bound_args.apply_defaults()\n all_kwargs = dict(bound_args.arguments)\n\n # Get the list of consumed lmps and clean the invocation params for serialization.\n cleaned_invocation_params, ipstr, consumes = prepare_invocation_params( all_kwargs)\n\n try_use_cache = hasattr(func_to_track.__wrapper__, \"__ell_use_cache__\")\n\n if try_use_cache:\n # Todo: add nice logging if verbose for when using a cahced invocaiton. IN a different color with thar args..\n if not hasattr(func_to_track, \"__ell_hash__\") and config.lazy_versioning:\n fn_closure, _ = ell.util.closure.lexically_closured_source(func_to_track)\n\n # compute the state cachekey\n state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)\n\n cache_store = func_to_track.__wrapper__.__ell_use_cache__\n cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key)\n\n\n if len(cached_invocations) > 0:\n # XXX: Fix caching.\n results = [d.deserialize() for d in cached_invocations[0].results]\n\n logger.info(f\"Using cached result for {func_to_track.__qualname__} with state cache key: {state_cache_key}\")\n if len(results) == 1:\n return results[0]\n else:\n return results\n # Todo: Unfiy this with the non-cached case. We should go through the same code pathway.\n else:\n logger.info(f\"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache...\")\n\n\n _start_time = utc_now()\n\n # XXX: thread saftey note, if I prevent yielding right here and get the global context I should be fine re: cache key problem\n\n # get the prompt\n (result, invocation_api_params, metadata) = (\n (func_to_track(*fn_args, **fn_kwargs), {}, {})\n if lmp_type == LMPType.OTHER\n else func_to_track(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, )\n )\n latency_ms = (utc_now() - _start_time).total_seconds() * 1000\n usage = metadata.get(\"usage\", {\"prompt_tokens\": 0, \"completion_tokens\": 0})\n prompt_tokens= usage.get(\"prompt_tokens\", 0) if usage else 0\n completion_tokens= usage.get(\"completion_tokens\", 0) if usage else 0\n\n\n #XXX: cattrs add invocation origin here recursively on all pirmitive types within a message.\n #XXX: This will allow all objects to be traced automatically irrespective origin rather than relying on the API to do it, it will of vourse be expensive but unify track.\n #XXX: No other code will need to consider tracking after this point.\n\n if not hasattr(func_to_track, \"__ell_hash__\") and config.lazy_versioning:\n ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)\n _serialize_lmp(func_to_track)\n\n if not state_cache_key:\n state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)\n\n _write_invocation(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens, \n state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id)\n\n if _get_invocation_id:\n return result, invocation_id\n else:\n return result\n finally:\n pop_invocation()\n\n\n func_to_track.__wrapper__ = tracked_func\n if hasattr(func_to_track, \"__ell_api_params__\"):\n tracked_func.__ell_api_params__ = func_to_track.__ell_api_params__\n if hasattr(func_to_track, \"__ell_params_model__\"):\n tracked_func.__ell_params_model__ = func_to_track.__ell_params_model__\n tracked_func.__ell_func__ = func_to_track\n tracked_func.__ell_track = True\n\n return tracked_func",
"filepath": "src\\ell\\lmp\\_track.py",
"metadata": {
"file_path": "src\\ell\\lmp\\_track.py",
"file_name": "_track.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 1199,
"span_ids": [
"_track"
],
"start_line": 43,
"end_line": 156,
"community": null
},
"node_id": "lmp\\_track.py::2"
},
{
"id": "lmp\\_track.py::3",
"input_type": "file",
"content": "def _serialize_lmp(func):\n # Serialize deptjh first all fo the used lmps.\n for f in func.__ell_uses__:\n _serialize_lmp(f)\n\n if getattr(func, \"_has_serialized_lmp\", False):\n return\n func._has_serialized_lmp = False\n fn_closure = func.__ell_closure__\n lmp_type = func.__ell_type__\n name = func.__qualname__\n api_params = getattr(func, \"__ell_api_params__\", None)\n\n lmps = config.store.get_versions_by_fqn(fqn=name)\n version = 0\n already_in_store = any(lmp.lmp_id == func.__ell_hash__ for lmp in lmps)\n\n if not already_in_store:\n commit = None\n if lmps:\n latest_lmp = max(lmps, key=lambda x: x.created_at)\n version = latest_lmp.version_number + 1\n if config.autocommit:\n # XXX: Move this out to autocommit itself.\n if not _autocommit_warning():\n from ell.util.differ import write_commit_message_for_diff\n commit = str(write_commit_message_for_diff(\n f\"{latest_lmp.dependencies}\\n\\n{latest_lmp.source}\", \n f\"{fn_closure[1]}\\n\\n{fn_closure[0]}\")[0])\n\n serialized_lmp = SerializedLMP(\n lmp_id=func.__ell_hash__,\n name=name,\n created_at=utc_now(),\n source=fn_closure[0],\n dependencies=fn_closure[1],\n commit_message=commit,\n initial_global_vars=get_immutable_vars(fn_closure[2]),\n initial_free_vars=get_immutable_vars(fn_closure[3]),\n lmp_type=lmp_type,\n api_params=api_params if api_params else None,\n version_number=version,\n )\n config.store.write_lmp(serialized_lmp, [f.__ell_hash__ for f in func.__ell_uses__])\n func._has_serialized_lmp = True",
"filepath": "src\\ell\\lmp\\_track.py",
"metadata": {
"file_path": "src\\ell\\lmp\\_track.py",
"file_name": "_track.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 434,
"span_ids": [
"_serialize_lmp"
],
"start_line": 158,
"end_line": 202,
"community": null
},
"node_id": "lmp\\_track.py::3"
},
{
"id": "lmp\\_track.py::4",
"input_type": "file",
"content": "def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion_tokens, \n state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id):\n\n invocation_contents = InvocationContents(\n invocation_id=invocation_id,\n params=cleaned_invocation_params,\n results=result,\n invocation_api_params=invocation_api_params,\n global_vars=get_immutable_vars(func.__ell_closure__[2]),\n free_vars=get_immutable_vars(func.__ell_closure__[3])\n )\n\n if invocation_contents.should_externalize and config.store.has_blob_storage:\n invocation_contents.is_external = True\n\n # Write to the blob store \n blob_id = config.store.blob_store.store_blob(\n json.dumps(invocation_contents.model_dump(\n ), default=str, ensure_ascii=False).encode('utf-8'),\n invocation_id\n )\n invocation_contents = InvocationContents(\n invocation_id=invocation_id,\n is_external=True,\n )\n\n invocation = Invocation(\n id=invocation_id,\n lmp_id=func.__ell_hash__,\n created_at=utc_now(),\n latency_ms=latency_ms,\n prompt_tokens=prompt_tokens,\n completion_tokens=completion_tokens,\n state_cache_key=state_cache_key,\n used_by_id=parent_invocation_id,\n contents=invocation_contents\n )\n\n config.store.write_invocation(invocation, consumes)",
"filepath": "src\\ell\\lmp\\_track.py",
"metadata": {
"file_path": "src\\ell\\lmp\\_track.py",
"file_name": "_track.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 296,
"span_ids": [
"_write_invocation"
],
"start_line": 204,
"end_line": 244,
"community": null
},
"node_id": "lmp\\_track.py::4"
},
{
"id": "lmp\\complex.py::1",
"input_type": "file",
"content": "from ell.configurator import config\nfrom ell.lmp._track import _track\nfrom ell.provider import EllCallParams\nfrom ell.types._lstr import _lstr\nfrom ell.types import Message, ContentBlock\nfrom ell.types.message import LMP, InvocableLM, LMPParams, MessageOrDict, _lstr_generic\nfrom ell.types.studio import LMPType\nfrom ell.util._warnings import _no_api_key_warning, _warnings\nfrom ell.util.verbosity import compute_color, model_usage_logger_pre\n\nfrom ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start\n\nfrom functools import wraps\nfrom typing import Any, Dict, Optional, List, Callable, Tuple, Union",
"filepath": "src\\ell\\lmp\\complex.py",
"metadata": {
"file_path": "src\\ell\\lmp\\complex.py",
"file_name": "complex.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 156,
"span_ids": [
"imports"
],
"start_line": 1,
"end_line": 14,
"community": null
},
"node_id": "lmp\\complex.py::1"
},
{
"id": "lmp\\complex.py::2",
"input_type": "file",
"content": "def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Callable]] = None, exempt_from_tracking=False, post_callback: Optional[Callable] = None, **api_params):\n default_client_from_decorator = client\n default_model_from_decorator = model\n default_api_params_from_decorator = api_params\n def parameterized_lm_decorator(\n prompt: LMP,\n ) -> Callable[..., Union[List[Message], Message]]:\n _warnings(model, prompt, default_client_from_decorator)\n\n @wraps(prompt)\n def model_call(\n *prompt_args,\n _invocation_origin : Optional[str] = None,\n client: Optional[Any] = None,\n api_params: Optional[Dict[str, Any]] = None,\n lm_params: Optional[DeprecationWarning] = None,\n **prompt_kwargs,\n ) -> Tuple[Any, Any, Any]:\n # XXX: Deprecation in 0.1.0\n if lm_params:\n raise DeprecationWarning(\"lm_params is deprecated. Use api_params instead.\")\n\n # promt -> str\n res = prompt(*prompt_args, **prompt_kwargs)\n # Convert prompt into ell messages\n messages = _get_messages(res, prompt)\n\n # XXX: move should log to a logger.\n should_log = not exempt_from_tracking and config.verbose\n # Cute verbose logging.\n if should_log: model_usage_logger_pre(prompt, prompt_args, prompt_kwargs, \"[]\", messages) #type: ignore\n\n # Call the model.\n # Merge API params\n merged_api_params = {**config.default_api_params, **default_api_params_from_decorator, **(api_params or {})}\n n = merged_api_params.get(\"n\", 1)\n # Merge client overrides & client registry\n merged_client = _client_for_model(model, client or default_client_from_decorator)\n ell_call = EllCallParams(\n # XXX: Could change behaviour of overriding ell params for dyanmic tool calls.\n model=merged_api_params.pop(\"model\", default_model_from_decorator),\n messages=messages,\n client = merged_client,\n api_params=merged_api_params,\n tools=tools or [],\n )\n # Get the provider for the model\n provider = config.get_provider_for(ell_call.client)\n assert provider is not None, f\"No provider found for client {ell_call.client}.\"\n\n if should_log: model_usage_logger_post_start(n)\n with model_usage_logger_post_intermediate(n) as _logger:\n (result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger if should_log else None)\n if isinstance(result, list) and len(result) == 1:\n result = result[0]\n\n result = post_callback(result) if post_callback else result\n if should_log:\n model_usage_logger_post_end()\n #\n # These get sent to track. This is wack. \n return result, final_api_params, metadata\n # ... other code\n # ... other code",
"filepath": "src\\ell\\lmp\\complex.py",
"metadata": {
"file_path": "src\\ell\\lmp\\complex.py",
"file_name": "complex.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 645,
"span_ids": [
"complex"
],
"start_line": 16,
"end_line": 77,
"community": null
},
"node_id": "lmp\\complex.py::2"
},
{
"id": "lmp\\complex.py::3",
"input_type": "file",
"content": "def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Callable]] = None, exempt_from_tracking=False, post_callback: Optional[Callable] = None, **api_params):\n def parameterized_lm_decorator(\n prompt: LMP,\n ) -> Callable[..., Union[List[Message], Message]]:\n # ... other code\n\n\n\n model_call.__ell_api_params__ = default_api_params_from_decorator #type: ignore\n model_call.__ell_func__ = prompt #type: ignore\n model_call.__ell_type__ = LMPType.LM #type: ignore\n model_call.__ell_exempt_from_tracking = exempt_from_tracking #type: ignore\n\n\n if exempt_from_tracking:\n return model_call\n else:\n # XXX: Analyze decorators with AST instead.\n return _track(model_call, forced_dependencies=dict(tools=tools, response_format=api_params.get(\"response_format\", {})))\n return parameterized_lm_decorator",
"filepath": "src\\ell\\lmp\\complex.py",
"metadata": {
"file_path": "src\\ell\\lmp\\complex.py",
"file_name": "complex.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 201,
"span_ids": [
"complex"
],
"start_line": 81,
"end_line": 92,
"community": null
},
"node_id": "lmp\\complex.py::3"
},
{
"id": "lmp\\complex.py::4",
"input_type": "file",
"content": "def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> list[Message]:\n \"\"\"\n Helper function to convert the output of an LMP into a list of Messages.\n \"\"\"\n if isinstance(prompt_ret, str):\n has_system_prompt = prompt.__doc__ is not None and prompt.__doc__.strip() != \"\"\n messages = [Message(role=\"system\", content=[ContentBlock(text=_lstr(prompt.__doc__ ) )])] if has_system_prompt else []\n return messages + [\n Message(role=\"user\", content=[ContentBlock(text=prompt_ret)])\n ]\n else:\n assert isinstance(\n prompt_ret, list\n ), \"Need to pass a list of Messages to the language model\"\n return prompt_ret",
"filepath": "src\\ell\\lmp\\complex.py",
"metadata": {
"file_path": "src\\ell\\lmp\\complex.py",
"file_name": "complex.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 162,
"span_ids": [
"_get_messages"
],
"start_line": 96,
"end_line": 110,
"community": null
},
"node_id": "lmp\\complex.py::4"
},
{
"id": "lmp\\complex.py::5",
"input_type": "file",
"content": "def _client_for_model(\n model: str,\n client: Optional[Any] = None,\n _name: Optional[str] = None,\n) -> Any:\n # XXX: Move to config to centralize api keys etc.\n if not client:\n client, was_fallback = config.get_client_for(model)\n\n # XXX: Wrong.\n if not client and not was_fallback:\n raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True))\n\n if client is None:\n raise ValueError(f\"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.\")\n return client\n\n\ncomplex.__doc__ =\n # ... other code",
"filepath": "src\\ell\\lmp\\complex.py",
"metadata": {
"file_path": "src\\ell\\lmp\\complex.py",
"file_name": "complex.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 166,
"span_ids": [
"impl",
"_client_for_model"
],
"start_line": 112,
"end_line": 327,
"community": null
},
"node_id": "lmp\\complex.py::5"
},
{
"id": "lmp\\simple.py::1",
"input_type": "file",
"content": "from functools import wraps\nfrom typing import Any, Optional\n\nfrom ell.lmp.complex import complex\n\n\ndef simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params):\n assert 'tools' not in api_params, \"tools are not supported in lm decorator, use multimodal decorator instead\"\n assert 'tool_choice' not in api_params, \"tool_choice is not supported in lm decorator, use multimodal decorator instead\"\n assert 'response_format' not in api_params or isinstance(api_params.get('response_format', None), dict), \"response_format is not supported in lm decorator, use multimodal decorator instead\"\n\n def convert_multimodal_response_to_lstr(response):\n return [x.content[0].text for x in response] if isinstance(response, list) else response.content[0].text\n return complex(model, client, exempt_from_tracking=exempt_from_tracking, **api_params, post_callback=convert_multimodal_response_to_lstr)",
"filepath": "src\\ell\\lmp\\simple.py",
"metadata": {
"file_path": "src\\ell\\lmp\\simple.py",
"file_name": "simple.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 209,
"span_ids": [
"simple",
"imports"
],
"start_line": 1,
"end_line": 14,
"community": null
},
"node_id": "lmp\\simple.py::1"
},
{
"id": "lmp\\simple.py::2",
"input_type": "file",
"content": "simple.__doc__ = \"\"\"The fundamental unit of language model programming in ell.\n\n This decorator simplifies the process of creating Language Model Programs (LMPs) \n that return text-only outputs from language models, while supporting multimodal inputs.\n It wraps the more complex 'complex' decorator, providing a streamlined interface for common use cases.\n\n :param model: The name or identifier of the language model to use.\n :type model: str\n :param client: An optional OpenAI client instance. If not provided, a default client will be used.\n :type client: Optional[openai.Client]\n :param exempt_from_tracking: If True, the LMP usage won't be tracked. Default is False.\n :type exempt_from_tracking: bool\n :param api_params: Additional keyword arguments to pass to the underlying API call.\n :type api_params: Any\n\n Usage:\n The decorated function can return either a single prompt or a list of ell.Message objects:\n\n .. code-block:: python\n\n @ell.simple(model=\"gpt-4\", temperature=0.7)\n def summarize_text(text: str) -> str:\n '''You are an expert at summarizing text.''' # System prompt\n return f\"Please summarize the following text:\\\\n\\\\n{text}\" # User prompt\n\n\n @ell.simple(model=\"gpt-4\", temperature=0.7)\n def describe_image(image : PIL.Image.Image) -> List[ell.Message]:\n '''Describe the contents of an image.''' # unused because we're returning a list of Messages\n return [\n # helper function for ell.Message(text=\"...\", role=\"system\")\n ell.system(\"You are an AI trained to describe images.\"),\n # helper function for ell.Message(content=\"...\", role=\"user\")\n ell.user([\"Describe this image in detail.\", image]),\n ]\n\n\n image_description = describe_image(PIL.Image.open(\"https://example.com/image.jpg\"))\n print(image_description) \n # Output will be a string text-only description of the image\n\n summary = summarize_text(\"Long text to summarize...\")\n print(summary)\n # Output will be a text-only summary\n\n Notes:\n\n - This decorator is designed for text-only model outputs, but supports multimodal inputs.\n - It simplifies complex responses from language models to text-only format, regardless of \n the model's capability for structured outputs, function calling, or multimodal outputs.\n - For preserving complex model outputs (e.g., structured data, function calls, or multimodal \n outputs), use the @ell.complex decorator instead. @ell.complex returns a Message object (role='assistant')\n - The decorated function can return a string or a list of ell.Message objects for more \n complex prompts, including multimodal inputs.\n - If called with n > 1 in api_params, the wrapped LMP will return a list of strings for the n parallel outputs\n of the model instead of just one string. Otherwise, it will return a single string.\n - You can pass LM API parameters either in the decorator or when calling the decorated function.\n Parameters passed during the function call will override those set in the decorator.\n\n Example of passing LM API params:\n\n .. code-block:: python\n\n @ell.simple(model=\"gpt-4\", temperature=0.7)\n def generate_story(prompt: str) -> str:\n return f\"Write a short story based on this prompt: {prompt}\"\n\n # Using default parameters\n story1 = generate_story(\"A day in the life of a time traveler\")\n\n # Overriding parameters during function call\n story2 = generate_story(\"An AI's first day of consciousness\", api_params={\"temperature\": 0.9, \"max_tokens\": 500})\n\n See Also:\n\n - :func:`ell.complex`: For LMPs that preserve full structure of model responses, including multimodal outputs.\n - :func:`ell.tool`: For defining tools that can be used within complex LMPs.\n - :mod:`ell.studio`: For visualizing and analyzing LMP executions.\n \"\"\"",
"filepath": "src\\ell\\lmp\\simple.py",
"metadata": {
"file_path": "src\\ell\\lmp\\simple.py",
"file_name": "simple.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 861,
"span_ids": [
"impl"
],
"start_line": 18,
"end_line": 96,
"community": null
},
"node_id": "lmp\\simple.py::2"
},
{
"id": "lmp\\tool.py::1",
"input_type": "file",
"content": "from functools import wraps\nimport json\nfrom typing import Any, Callable, Optional\n\nfrom pydantic import Field, create_model\nfrom pydantic.fields import FieldInfo\nfrom ell.lmp._track import _track\n# from ell.types import ToolFunction, InvocableTool, ToolParams\n# from ell.util.verbosity import compute_color, tool_usage_logger_pre\nfrom ell.configurator import config\nfrom ell.types._lstr import _lstr\nfrom ell.types.studio import LMPType\nimport inspect\n\nfrom ell.types.message import ContentBlock, InvocableTool, ToolResult, to_content_blocks",
"filepath": "src\\ell\\lmp\\tool.py",
"metadata": {
"file_path": "src\\ell\\lmp\\tool.py",
"file_name": "tool.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 127,
"span_ids": [
"imports"
],
"start_line": 1,
"end_line": 15,
"community": null
},
"node_id": "lmp\\tool.py::1"
},
{
"id": "lmp\\tool.py::2",
"input_type": "file",
"content": "def tool(*, exempt_from_tracking: bool = False, **tool_kwargs):\n def tool_decorator(fn: Callable[..., Any]) -> InvocableTool:\n _under_fn = fn\n\n @wraps(fn)\n def wrapper(\n *fn_args,\n _invocation_origin: str = None,\n _tool_call_id: str = None,\n **fn_kwargs\n ):\n #XXX: Post release, we need to wrap all tool arguments in type primitives for tracking I guess or change that tool makes the tool function inoperable.\n #XXX: Most people are not going to manually try and call the tool without a type primitive and if they do it will most likely be wrapped with l strs.\n\n if config.verbose and not exempt_from_tracking:\n pass\n # tool_usage_logger_pre(fn, fn_args, fn_kwargs, name, color)\n\n result = fn(*fn_args, **fn_kwargs)\n\n _invocation_api_params = dict(tool_kwargs=tool_kwargs)\n\n # Here you might want to add logic for tracking the tool usage\n # Similar to how it's done in the lm decorator # Use _invocation_origin\n\n if isinstance(result, str) and _invocation_origin:\n result = _lstr(result,origin_trace=_invocation_origin)\n\n #XXX: This _tool_call_id thing is a hack. Tracking should happen via params in the api\n # So if you call wiuth a _tool_callId\n if _tool_call_id:\n # XXX: TODO: MOVE TRACKING CODE TO _TRACK AND OUT OF HERE AND API.\n try:\n if isinstance(result, ContentBlock):\n content_results = [result]\n elif isinstance(result, list) and all(isinstance(c, ContentBlock) for c in result):\n content_results = result\n else:\n content_results = [ContentBlock(text=_lstr(json.dumps(result, ensure_ascii=False),origin_trace=_invocation_origin))]\n except TypeError as e:\n raise TypeError(f\"Failed to convert tool use result to ContentBlock: {e}. Tools must return json serializable objects. or a list of ContentBlocks.\")\n # XXX: Need to support images and other content types somehow. We should look for images inside of the the result and then go from there.\n # try:\n # content_results = coerce_content_list(result)\n # except ValueError as e:\n\n # TODO: poolymorphic validation here is important (cant have tool_call or formatted_response in the result)\n # XXX: Should we put this coercion here or in the tool call/result area.\n for c in content_results:\n assert not c.tool_call, \"Tool call in tool result\"\n # assert not c.formatted_response, \"Formatted response in tool result\"\n if c.parsed:\n # Warning: Formatted response in tool result will be converted to text\n # TODO: Logging needs to produce not print.\n print(f\"Warning: Formatted response in tool result will be converted to text. Original: {c.parsed}\")\n c.text = _lstr(c.parsed.model_dump_json(),origin_trace=_invocation_origin)\n c.parsed = None\n assert not c.audio, \"Audio in tool result\"\n return ToolResult(tool_call_id=_tool_call_id, result=content_results), _invocation_api_params, {}\n else:\n return result, _invocation_api_params, {}\n # ... other code\n # ... other code",
"filepath": "src\\ell\\lmp\\tool.py",
"metadata": {
"file_path": "src\\ell\\lmp\\tool.py",
"file_name": "tool.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 723,
"span_ids": [
"tool"
],
"start_line": 19,
"end_line": 79,
"community": null
},
"node_id": "lmp\\tool.py::2"
},
{
"id": "lmp\\tool.py::3",
"input_type": "file",
"content": "def tool(*, exempt_from_tracking: bool = False, **tool_kwargs):\n def tool_decorator(fn: Callable[..., Any]) -> InvocableTool:\n # ... other code\n\n\n wrapper.__ell_tool_kwargs__ = tool_kwargs\n wrapper.__ell_func__ = _under_fn\n wrapper.__ell_type__ = LMPType.TOOL\n wrapper.__ell_exempt_from_tracking = exempt_from_tracking\n\n # Construct the pydantic mdoel for the _under_fn's function signature parameters.\n # 1. Get the function signature.\n\n sig = inspect.signature(fn)\n\n # 2. Create a dictionary of field definitions for the Pydantic model\n fields = {}\n for param_name, param in sig.parameters.items():\n # Skip *args and **kwargs\n if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):\n continue\n\n # Determine the type annotation\n if param.annotation == inspect.Parameter.empty:\n raise ValueError(f\"Parameter {param_name} has no type annotation, and cannot be converted into a tool schema for OpenAI and other provisders. Should OpenAI produce a string or an integer, etc, for this parameter?\")\n annotation = param.annotation\n\n # Determine the default value\n default = param.default\n\n # Check if the parameter has a Field with description\n if isinstance(param.default, FieldInfo):\n field = param.default\n fields[param_name] = (annotation, field)\n elif param.default != inspect.Parameter.empty:\n fields[param_name] = (annotation, param.default)\n else:\n # If no default value, use Field without default\n fields[param_name] = (annotation, Field(...))\n\n # 3. Create the Pydantic model\n model_name = f\"{fn.__name__}\"\n ParamsModel = create_model(model_name, **fields)\n\n # Attach the Pydantic model to the wrapper function\n wrapper.__ell_params_model__ = ParamsModel\n\n # handle tracking last.\n if exempt_from_tracking:\n ret = wrapper\n else:\n ret= _track(wrapper)\n\n # Helper function to get the Pydantic model for the tool\n def get_params_model():\n return wrapper.__ell_params_model__\n\n # Attach the helper function to the wrapper\n wrapper.get_params_model = get_params_model\n ret.get_params_model = get_params_model\n return ret\n\n return tool_decorator",
"filepath": "src\\ell\\lmp\\tool.py",
"metadata": {
"file_path": "src\\ell\\lmp\\tool.py",
"file_name": "tool.py",
"file_type": "text/x-python",
"category": "implementation",
"tokens": 510,
"span_ids": [
"tool"
],
"start_line": 82,
"end_line": 139,
"community": null
},
"node_id": "lmp\\tool.py::3"
},
{
"id": "lmp\\tool.py::4",
"input_type": "file",
"content": "tool.__doc__ = \"\"\"Defines a tool for use in language model programs (LMPs) that support tool use.\n\nThis decorator wraps a function, adding metadata and handling for tool invocations.\nIt automatically extracts the tool's description and parameters from the function's\ndocstring and type annotations, creating a structured representation for LMs to use.\n\n:param exempt_from_tracking: If True, the tool usage won't be tracked. Default is False.\n:type exempt_from_tracking: bool\n:param tool_kwargs: Additional keyword arguments for tool configuration.\n:return: A wrapped version of the original function, usable as a tool by LMs.\n:rtype: Callable\n\nRequirements:\n\n- Function must have fully typed arguments (Pydantic-serializable).\n- Return value must be one of: str, JSON-serializable object, Pydantic model, or List[ContentBlock].\n- All parameters must have type annotations.\n- Complex types should be Pydantic models.\n- Function should have a descriptive docstring.\n- Can only be used in LMPs with @ell.complex decorators\n\nFunctionality:\n\n1. Metadata Extraction:\n - Uses function docstring as tool description.\n - Extracts parameter info from type annotations and docstring.\n - Creates a Pydantic model for parameter validation and schema generation.\n\n2. Integration with LMs:\n - Can be passed to @ell.complex decorators.\n - Provides structured tool information to LMs.\n\n3. Invocation Handling:\n - Manages tracking, logging, and result processing.\n - Wraps results in appropriate types (e.g., _lstr) for tracking.\n\nUsage Modes:\n\n1. Normal Function Call:\n - Behaves like a regular Python function.\n - Example: result = my_tool(arg1=\"value\", arg2=123)\n\n2. LMP Tool Call:\n - Used within LMPs or with explicit _tool_call_id.\n - Returns a ToolResult object.\n - Example: result = my_tool(arg1=\"value\", arg2=123, _tool_call_id=\"unique_id\")\n\nResult Coercion:\n\n- String \u2192 ContentBlock(text=result)\n- Pydantic BaseModel \u2192 ContentBlock(parsed=result)\n- List[ContentBlock] \u2192 Used as-is\n- Other types \u2192 ContentBlock(text=json.dumps(result))\n\nExample::\n\n @ell.tool()\n def create_claim_draft(\n claim_details: str,\n claim_type: str,\n claim_amount: float,\n claim_date: str = Field(description=\"Date format: YYYY-MM-DD\")\n ) -> str:\n '''Create a claim draft. Returns the created claim ID.'''\n return \"12345\"\n\n # For use in a complex LMP:\n @ell.complex(model=\"gpt-4\", tools=[create_claim_draft], temperature=0.1)\n def insurance_chatbot(message_history: List[Message]) -> List[Message]:\n # Chatbot implementation...\n\n x = insurance_chatbot([\n ell.user(\"I crashed my car into a tree.\"),\n ell.assistant(\"I'm sorry to hear that. Can you provide more details?\"),\n ell.user(\"The car is totaled and I need to file a claim. Happened on 2024-08-01. total value is like $5000\")\n ]) \n print(x)\n '''ell.Message(content=[\n ContentBlock(tool_call(\n tool_call_id=\"asdas4e\",\n tool_fn=create_claim_draft,\n input=create_claim_draftParams({\n claim_details=\"The car is totaled and I need to file a claim. Happened on 2024-08-01. total value is like $5000\",\n claim_type=\"car\",\n claim_amount=5000,\n claim_date=\"2024-08-01\"\n })\n ))\n ], role='assistant')'''\n \n if x.tool_calls:\n next_user_message = response_message.call_tools_and_collect_as_message()\n # This actually calls create_claim_draft\n print(next_user_message)\n '''\n ell.Message(content=[\n ContentBlock(tool_result=ToolResult(\n tool_call_id=\"asdas4e\",\n result=[ContentBlock(text=\"12345\")]\n ))\n ], role='user')\n '''\n y = insurance_chatbot(message_history + [x, next_user_message])\n print(y)\n '''\n ell.Message(\"I've filed that for you!\", role='assistant')\n '''\n\nNote:\n- Tools are integrated into LMP calls via the 'tools' parameter in @ell.complex.\n- LMs receive structured tool information, enabling understanding and usage within the conversation context.\n \"\"\"",